--- /dev/null
+[SamPolicy]
+preset=CCD_FOR_OO
--- /dev/null
+# External code: Android NN API
+/ONE/compiler/ann-api/include/NeuralNetworks.h
+/ONE/compiler/ann-ref
+
+# Eigen
+/ONE/compiler/nnc/backends/soft_backend/code_snippets/eigen.def
+
+# Test codes
+/ONE/tests
+
+# Flatbuffers generated
+/ONE/runtime/onert/frontend/circle_schema/include/circle_schema_generated.h
+/ONE/runtime/onert/frontend/tflite/src/tflite_schema_generated.h
+
+# External code: Android NN API
+/ONE/runtime/nnapi-header/include/NeuralNetworks.h
+/ONE/runtime/nnapi-header/include/NeuralNetworksExtensions.h
+
+# External code: Tensorflow lite
+/ONE/runtime/libs/nnapi
+/ONE/runtime/libs/profiling
+
+# External code: 3rd party
+/ONE/runtime/3rdparty
+
+# External code: compute libraries
+/ONE/compute
+
+# Experimental subprojects not for release
+/ONE/runtime/contrib
+
+# Downloaded externals
+/ONE/externals
version: 2
test:
- - name: NN Runtime
+ - name: NN_Runtime
testCaseLanguage: CPP
testFW: GTEST
testCaseFolder:
- - ./compute/test/cker
- - ./runtime/onert/core/src/backend/basic
- - ./runtime/onert/frontend/nnapi
- - ./runtime/onert/test/core/compiler
- - ./runtime/onert/test/core/exec
- - ./runtime/onert/test/core/interp
- - ./runtime/onert/test/graph
- - ./runtime/onert/test/graph/operand
- - ./runtime/onert/test/graph/operation
- - ./runtime/onert/test/graph/verifier
- - ./runtime/onert/test/ir
- - ./runtime/onert/test/util
- - ./tests/nnapi/src
- - ./tests/nnfw_api/src
- - ./tests/tools/tflite_run/src
+ - /compute/test/cker
+ - /runtime/onert/core/src/backend/basic
+ - /runtime/onert/frontend/nnapi
+ - /runtime/onert/test/core/compiler
+ - /runtime/onert/test/core/exec
+ - /runtime/onert/test/core/interp
+ - /runtime/onert/test/graph
+ - /runtime/onert/test/graph/operand
+ - /runtime/onert/test/graph/operation
+ - /runtime/onert/test/graph/verifier
+ - /runtime/onert/test/ir
+ - /runtime/onert/test/util
+ - /tests/nnapi/src
+ - /tests/nnfw_api/src
+ - /tests/tools/tflite_run/src
testFile:
- extension: cpp
any: true
- extension: cc
any: true
-
+ - excludes :
+ - ArgMinMax.cc
+ - AveragePool2D.cc
+ - Concat.cc
+ - DepthToSpace.cc
+ - DepthwiseConv2D.cc
+ - Fill.cc
+ - If.cc
+ - Pad.cc
+ - Reduce.cc
+      - ResizeBilinear.cc
+ - Slice.cc
+ - Softmax.cc
+ - While.cc
testCase:
- condition:
- functionName:
starts:
- TEST
+ - excludes :
+ - Verifier.dag_checker
+ - graph_operand_LayoutSet.layout_set_operators
+ - InterpExecutorTest.executeTwoStep
+ - InterpExecutorTest.execute
+ - InterpExecutorTest.setOutput
+ - InterpExecutorTest.create_empty
+ - InterpExecutorTest.setOutputForUnspecifiedDimensions
+ - InterpExecutorTest.setInputForUnspecifiedDimensions
+ - InterpExecutorTest.setInput
+ - InterpExecutorTest.create_simple
+ - ExecTime.structure
+ - ExecTime.roundtrip_ok
+ - SchedulerTest.branched_graph_profiling_mode
+ - SchedulerTestWithExecutorParam.straight_graph_known_exec_time
+ - SchedulerTestWithExecutorParam.branched_graph_known_exec_time
+ - TFLite_test_case.simple_test
+ - ExecInstance.simple
+ - ExecInstance.twoExecution
+ - ExecInstance.twoCompile
+ - ExecInstance.async
+ - ExecInstance.twoThreads
+ - graph_operand_usedef.usedef_test
+ - Graph.inputs_and_outputs
+ - nnfw_create_session.Test_001
+ - nnfw_create_session.Negative_001
+ - WICPlanner.claim_release_test
+ - BumpPlanner.claim_test
+ - Allocator.allocate_test
+ - FirstFitPlanner.claim_release_test
+ - graph_operation_setIO.operation_setIO_concat
+ - graph_operation_setIO.operation_setIO_conv
+ - ValidationTest.neg_prepare_001
+ - ValidationTestOneOpModelLoaded.prepare_001
+ - graph_OperandIndexSequence.replace
+ - graph_OperandIndexSequence.append
+ - MODEL.model_build
+ - graph_operation_Set.operation_test
+ - graph_operand_Set.set_test
+ - ValidationTestSessionCreated.neg_load_session_001
+ - ValidationTestSessionCreated.load_session_001
+ - ShapeInference.Pool2DNodeExplicit
+ - ShapeInference.Elementwise
+ - ShapeInference.Concat
+ - ShapeInference.Pool2DNodeSame
+ - ShapeInference.IncorrectElementwise
+ - ShapeInference.Conv2D
+ - ShapeInference.Pool2DNodeValid
+ - ShapeInference.FullyConnected
+ - ShapeInference.DepthwiseConv2D
+ - ObjectManager.non_const_iterate
+ - ObjectManager.const_iterate
+ - ObjectManager.emplace
+ - ObjectManager.remove_2
+ - ObjectManager.remove_1
+ - ObjectManager.push
+ - Index.index_test
negativeTestCase:
- condition:
positiveTestCase:
- condition:
- inverse: negativeTestCase
+
+ - name: NN_Compiler
+ testCaseLanguage: CPP
+ testFW: GTEST
+ testCaseFolder:
+ - /compiler/angkor
+ - /compiler/arser
+ - /compiler/circle2circle
+ - /compiler/circle-quantizer
+ - /compiler/circle-tensordump
+ - /compiler/circlechef
+ - /compiler/circledump
+ - /compiler/crew
+ - /compiler/cwrap
+ - /compiler/foder
+ - /compiler/hermes
+ - /compiler/hermes-std
+ - /compiler/loco
+ - /compiler/locomotiv
+ - /compiler/locop
+ - /compiler/logo
+ - /compiler/logo-core
+ - /compiler/luci
+ - /compiler/luci-interpreter
+ - /compiler/luci-eval-driver
+ - /compiler/luci-pass-value-test
+ - /compiler/luci-value-test
+ - /compiler/mio-circle
+ - /compiler/mio-tflite
+ - /compiler/oops
+ - /compiler/pepper-assert
+ - /compiler/pepper-str
+ - /compiler/pepper-strcast
+ - /compiler/pp
+ - /compiler/record-minmax
+ - /compiler/safemain
+ - /compiler/souschef
+ - /compiler/tflchef
+ - /compiler/tflite2circle
+ - /compiler/vconone
+
+ testFile:
+ - extension: .test.cpp
+ any: true
+
+ testCase:
+ - condition:
+ - functionName:
+ starts:
+ - TEST
+ - excludes :
+ - ConstantFolding.const_relu_to_const
+ - ConstantFolding.const_relu_to_concat
+ - ADT_TENSOR_OVERLAY.access
+ - ADT_TENSOR_OVERLAY.ctor
+ - ADT_TENSOR_OVERLAY.read
+ - NodeExecution_BiasEncode.f32
+ - NodeExecution_BiasEncode.s32
+ - NodeExecution_EltwiseDiv.f32
+ - CircleLogicalOrTest.constructor_P
+ - NodeExecution_TensorConcat.f32_2
+ - NodeExecution_TensorConcat.f32
+ - CircleShapeInferenceRuleTest.avgpool2d_valid
+ - CircleShapeInferenceRuleTest.TFAdd_shapeinf_different
+ - CircleShapeInferenceRuleTest.minimal_with_CircleRelu
+ - CircleShapeInferenceRuleTest.CircleTranspose_simple
+ - CircleShapeInferenceRuleTest.avgpool2d_same
+ - CircleConv2Dest.constructor_P
+ - ADT_TENSOR_BUFFER.access
+ - ADT_TENSOR_BUFFER.ctor
+ - CircleRelu6Test.constructor_P
+ - Circle2CircleTest.NoArg_NEG
+ - CircleInstanceNormTest.constructor
+ - ADT_KERNEL_INDEX_ENUMERATOR.iterate_full_range
+ - ADT_TENSOR_INDEX_ENUMERATOR.iterate_full_range
+ - CirclePadTest.constructor_P
+ - ADT_KERNEL_KERNEL_NHWC_LAYOUT.n_increment
+ - ADT_KERNEL_KERNEL_NHWC_LAYOUT.col_increment
+ - ADT_KERNEL_KERNEL_NHWC_LAYOUT.ch_increment
+ - ADT_KERNEL_KERNEL_NHWC_LAYOUT.row_increment
+ - ADT_TENSOR_LEXICAL_LAYOUT.last
+ - ADT_TENSOR_LEXICAL_LAYOUT.lexical_first
+ - ADT_TENSOR_LEXICAL_LAYOUT.lexical_middle
+ - FeatureShapeTest.settet_and_getter
+ - FeatureShapeTest.default_constructor
+ - INDENTED_STRING_BUILDER.usage
+ - NodeExecution_Fixed_Reduce_Mean.f32_1
+ - NodeExecution_Fixed_Reduce_Mean.f32_0
+ - CircleAbsTest.constructor
+ - CircleMaximumTest.constructor_P
+ - FORMAT.simple_string
+ - FORMAT.concat_rvalue
+ - FORMAT.concat_lvalue
+ - FORMAT.simple_number
+ - ADT_KERNEL_BUFFER.ctor
+ - ADT_KERNEL_BUFFER.access
+ - ADT_TENSOR_SHAPE.num_elements_rank_0
+ - ADT_TENSOR_SHAPE.squeeze_neg_0
+ - ADT_TENSOR_SHAPE.num_elements_zero
+ - ADT_TENSOR_SHAPE.copy
+ - ADT_TENSOR_SHAPE.eq_negative_on_unmatched_dim
+ - ADT_TENSOR_SHAPE.num_elements_nulldim
+ - ADT_TENSOR_SHAPE.eq_positive
+ - ADT_TENSOR_SHAPE.squeeze_pos
+ - ADT_TENSOR_SHAPE.resize
+ - ADT_TENSOR_SHAPE.ctor_initializer_list
+ - ADT_TENSOR_SHAPE.squeeze_neg
+ - ADT_TENSOR_SHAPE.squeeze_nested
+ - ADT_TENSOR_SHAPE.num_elements_nonzero
+ - ADT_TENSOR_SHAPE.eq_negative_on_unmatched_rank
+ - ADT_TENSOR_SHAPE.dim
+ - ADT_TENSOR_SHAPE.ctor
+ - GraphBuilderTest.Usecase_000
+ - QueueTest.take
+ - MultiDialectShapeInferenceRuleTest.test1
+ - AlgorithmTest.postorder_traversal_incomplte_graph
+ - AlgorithmTest.active_nodes
+ - AlgorithmTest.postorder_traversal_visit_once
+ - AlgorithmTest.postorder_traversal
+ - CircleSquaredDifferenceTest.constructor_P
+ - NodeShapeTest.feature_shape_constructor
+ - NodeShapeTest.filter_shape_constructor
+ - NodeShapeTest.default_constructor
+ - NodeShapeTest.copy_constructible
+ - NodeShapeTest.tensor_shape_constructor
+ - NodeShapeTest.dwfilter_shape_constructor
+ - NodeShapeTest.bias_shape_constructor
+ - ADT_KERNEL_KERNEL_NCHW_LAYOUT.n_increment
+ - ADT_KERNEL_KERNEL_NCHW_LAYOUT.col_increment
+ - ADT_KERNEL_KERNEL_NCHW_LAYOUT.row_increment
+ - ADT_KERNEL_KERNEL_NCHW_LAYOUT.ch_increment
+ - CircleEqualTest.constructor_P
+ - VerifierTest.valid_error_reporter
+ - VerifierTest.valid_minimal
+ - DataTypeTraitsTest.FLOAT32
+ - NodeExecution_EltwiseSub.f32
+ - NodeExecution_FeatureCodec.s32
+ - NodeExecution_FeatureCodec.f32
+ - ADT_TENSOR_INDEX.operator_add
+ - ADT_TENSOR_INDEX.ctor_initializer_list
+ - ADT_TENSOR_INDEX.fill
+ - ADT_TENSOR_INDEX.operator_eqaul
+ - ADT_TENSOR_INDEX.resize
+ - ADT_TENSOR_INDEX.operator_add_different_size
+ - ADT_TENSOR_INDEX.at
+ - ADT_TENSOR_INDEX.ctor
+ - ADT_TENSOR_INDEX.copy
+ - ADT_KERNEL_OVERLAY.access
+ - ADT_KERNEL_OVERLAY.read
+ - ADT_KERNEL_OVERLAY.ctor
+ - BiasShapeTest.default_constructor
+ - FildesTest.destructor
+ - FildesTest.value_constructor
+ - FildesTest.move_constructor
+ - FildesTest.default_constructor
+ - CircleGatherTest.constructor
+ - LinearV1FormatterTest.node_summary_builder_composition
+ - LinearV1FormatterTest.user_defined_node_summary_builder
+ - LinearV1FormatterTest.simple
+ - SourceTest.construct
+ - SourceTest.macro
+ - CircleFullyConnectedTest.constructor
+ - ADT_FEATURE_OVERLAY.read
+ - ADT_FEATURE_OVERLAY.access
+ - ADT_FEATURE_OVERLAY.ctor
+ - ContextTest.constructor
+ - CircleDivTest.constructor_P
+ - NodeExecution_Reshape.f32
+ - MultiDialectTypeInferenceRuleTest.test1
+ - CanonicalTypeInferenceRuleTest.relu6
+ - TypeInferenceTest.framework
+ - CanonicalTypeInferenceRuleTest.tensor_broadcast
+ - CanonicalTypeInferenceRuleTest.minimal
+ - PermutingDecoderTest.feature
+ - PemutationTest.feature
+ - PermutingEncoderTest.depthwisefilter_init
+ - PermutingDecoderTest.filter
+ - PermutingEncoderTest.depthwise_filter
+ - PermutingEncoderTest.filter
+ - PermutingEncoderTest.feature_clone
+ - PermutingEncoderTest.feature
+ - PermutingDecoderTest.depthwise_filter
+ - PemutationTest.depthwise_filter
+ - PermutingDecoderTest.feature_clone
+ - PemutationTest.filter
+ - PadTest.default_constructor_2D
+ - NodeDomain.as_annotation
+ - CirclePackTest.constructor
+ - ADT_TENSOR_LAYOUT.move
+ - ADT_TENSOR_LAYOUT.ctor
+ - ADT_TENSOR_LAYOUT.copy
+ - DepthwiseFilterShapeTest.settet_and_getter
+ - DepthwiseFilterShapeTest.default_constructor
+ - CircleTypeInferenceRuleTest.minimal_with_CircleRelu
+ - GenericNodeSummaryBuilderTest.simple
+ - LogoPassTests.pass_name_over_unnamed_pass
+ - LogoPassTests.pass_name_over_named_pass
+ - CircleReluTest.constructor_P
+ - PaddingNDTest.default_constructor_ND
+ - TensorShapeTest.copy
+ - TensorShapeTest.rank
+ - TensorShapeTest.element_count
+ - TensorShapeTest.dim
+ - TensorShapeTest.rank_update
+ - TensorShapeTest.default_constructor
+ - TensorShapeTest.initializer_list_constructor
+ - DepthwiseFilterIndexTest.settet_and_getter
+ - DepthwiseFilterIndexTest.default_constructor
+ - MemoryTest.make_unique
+ - AnnotatedItemTest.annotation
+ - NodeExecution_DepthwiseFilterEncode.f32
+ - CircleBatchToSpaceNDTest.constructor
+ - WindowTest.setter_and_getter_2D
+ - WindowTest.default_constructor_2D
+ - NodeExecution_Tanh.f32
+ - MessageBufferTest.pass_constructed_message_on_descturction
+ - NodeExecution_TensorBroadcast.f32
+ - CircleSubTest.constructor_P
+ - NodeExecution_AvgPool2D.f32_1x3x3x1_calculation
+ - NodeExecution_AvgPool2D.f32_1x4x4x1_calculation
+ - NodeExecution_Conv2D.f32_multiple_channel
+ - NodeExecution_Conv2D.f32_1x5x5x1_calculation
+ - NodeExecution_Conv2D.with_padding
+ - ADT_FEATURE_HWC_LAYOUT.W_increase
+ - ADT_FEATURE_HWC_LAYOUT.C_increase
+ - ADT_FEATURE_HWC_LAYOUT.H_increase
+ - SimplifyDomainConversionPass.FilterEncode_FilterDecode_equal_perms
+ - SimplifyDomainConversionPass.FilterEncode_FilterDecode_different_perms
+ - CircleDialectTest.get_N
+ - CircleDialectTest.get_P
+ - LINEAR_DOCUMENT.line
+ - LINEAR_DOCUMENT.lines
+ - NodeExecution_Push.f32
+ - NodeExecution_Push.s32
+ - NodeExecution_DepthwiseConv2D.f32_random_valid
+ - NodeExecution_Pad.tensor_constant_pad_6_dim
+ - NodeExecution_Pad.tensor_constant_pad_1_dim
+ - NodeExecution_Pad.tensor_constant_pad_4_dim
+ - DepthwiseConv2DTest.constructor
+ - ConstGenTest.constructor_s32
+ - TransposedConv2DTest.constructor
+ - PullTest.shape
+ - MatrixDecodeTest.constructor
+ - FilterEncodeTest.constructor
+ - AvgPool2DTest.constructor
+ - Reshape_Fixed_Test.shape
+ - TensorConcatTest.constructor
+ - EltwiseSqrtTest.constructor
+ - TensorBiasAddTest.alias
+ - EltwiseSubTest.constructor
+ - TensorBroadcastTest.mapping
+ - PullTest.constructor
+ - PushTest.shape
+ - MaxPool2DTest.pad
+ - EltwiseMulTest.constructor
+ - DepthwiseFilterEncodeTest.constructor
+ - ForwardTest.constructor
+ - MaxPool2DTest.constructor
+ - TransposeTest.perm
+ - MatMulTest.constructor
+ - FeatureBiasAddTest.constructor
+ - TensorBroadcastTest.constructor
+ - FeatureEncodeTest.constructor
+ - MatrixEncodeTest.constructor
+ - ReLUTest.constructor
+ - BiasEncodeTest.constructor
+ - FilterDecodeTest.constructor
+ - EltwiseDivTest.constructor
+ - PushTest.constructor
+ - EltwiseAddTest.constructor
+ - Conv2DTest.constructor
+ - EltwiseMaxTest.constructor
+ - Reshape_Fixed_Test.constructor
+ - TransposeTest.constructor
+ - ConstGenTest.constructor
+ - FeatureBiasAddTest.alias
+ - DepthwiseFilterDecodeTest.constructor
+ - ReLU6Test.constructor
+ - FeatureDecodeTest.constructor
+ - TensorBiasAddTest.constructor
+ - NodeExecution_ReLU6.f32
+ - CircleSqrtTest.constructor_P
+ - CircleRsqrtTest.constructor
+ - LINEAR_DOCUMENT.append_empty_string
+ - LINEAR_DOCUMENT.indent
+ - LINEAR_DOCUMENT.append_multi_line_text
+ - LINEAR_DOCUMENT.append_void
+ - LINEAR_DOCUMENT.document_append
+ - LINEAR_DOCUMENT.formatted_append
+ - LINEAR_DOCUMENT.forward_append
+ - LINEAR_DOCUMENT.reverse_append
+ - NodeData.as_s32_buffer_wrapper
+ - NodeData.as_f32_buffer_wrapper
+ - ConsoleReporterTest.constructor
+ - ConsoleReporterTest.notify
+ - NodeExecution_TensorBiasAdd.f32
+ - NodeExecution_FeatureBiasAdd.f32
+ - ADT_KERNEL_SHAPE.num_elements
+ - ADT_KERNEL_SHAPE.operator_eq
+ - ADT_KERNEL_SHAPE.ctor
+ - CircleLogicalNotTest.constructor_P
+ - CircleConcatenationTest.constructor_P
+ - ModuleTest.add_more
+ - ModuleTest.consturctor
+ - ModuleTest.add
+ - ModuleTest.add_nullptr_NEG
+ - ModuleTest.graph_index_overflow_NEG
+ - CircleArgMaxTest.constructor_P
+ - CircleReshapeTest.alloc_new_shape_P
+ - CircleReshapeTest.constructor_P
+ - CircleAddTest.constructor_P
+ - CanonicalShapeInferenceRuleTest.tensor_concat
+ - CanonicalShapeInferenceRuleTest.feature_codec
+ - CanonicalShapeInferenceRuleTest.maxpool2d
+ - CanonicalShapeInferenceRuleTest.minimal
+ - CanonicalShapeInferenceRuleTest.const_gen
+ - CanonicalShapeInferenceRuleTest.depthwiseconv2d
+ - CanonicalShapeInferenceRuleTest.infer_v2
+ - CanonicalShapeInferenceRuleTest.avgpool2d
+ - CanonicalShapeInferenceRuleTest.tensor_broadcast
+ - CanonicalShapeInferenceRuleTest.transposedconv2d
+ - CanonicalShapeInferenceRuleTest.fixed_reshape
+ - CanonicalShapeInferenceRuleTest.relu
+ - CanonicalShapeInferenceRuleTest.tensor_transpose
+ - NodeExecution_Softmax.f32
+ - CircleCosTest.constructor_P
+ - HermesTest.simple_usecase
+ - CircleMaxPool2DTest.constructor_P
+ - GraphTest.graph_node_enumeration
+ - GraphTest.graph_name
+ - GraphTest.create_input
+ - NamedTest.constructor
+ - NamedTest.setter_and_getter
+ - GraphTest.create_and_destroy_node
+ - GraphTest.graph_name_nullptr_NEG
+ - DataTypedMixinTest.constructor
+ - DataTypedMixinTest.setter_and_getter
+ - GraphTest.consturctor_with_param_node
+ - TensorShapedMixinTest.setter_and_getter
+ - GraphTest.getters_over_const_instance
+ - GraphTest.create_output
+ - GraphTest.graph_inout_enumeration
+ - StrideTest.default_constructor_2D
+ - StrideTest.setter_and_getter_2D
+ - ADT_FEATURE_CHW_LAYOUT.col_increase
+ - ADT_FEATURE_CHW_LAYOUT.ch_increase
+ - ADT_FEATURE_CHW_LAYOUT.row_increase
+ - TensorIndexTest.copy
+ - TensorIndexTest.fill
+ - TensorIndexTest.at
+ - TensorIndexTest.ctor_initializer_list
+ - TensorIndexTest.resize
+ - TensorIndexTest.ctor
+ - NodeDataImpl.as_annotation
+ - MUILTI_LINE_TEXT_UTILS.operator_shift
+ - SeverityTest.fatal
+ - SeverityTest.warn
+ - SeverityTest.error
+ - SeverityTest.info
+ - SeverityTest.verbose
+ - MessageTest.ctor
+ - MessageTextTest.multiline
+ - NodeExecution_TransposedConv2D.f32
+ - ADT_FEATURE_BUFFER.ctor
+ - ADT_FEATURE_BUFFER.access
+ - UseTest.constructor
+ - UseTest.link_node
+ - NodeExecution_FilterEncode.f32
+ - NodeExecution_FilterEncode.s32
+ - CircleTransposeTest.constructor_P
+ - DimensionTest.value_constructor
+ - DimensionTest.default_constructor
+ - DimensionTest.operator_eq
+ - DimensionTest.make_unknown_dimension
+ - DimensionTest.unset
+ - DimensionTest.set
+ - TensorShapeTest.ctor_initializer_list
+ - TensorShapeTest.eq_negative_on_unmatched_dim
+ - TensorShapeTest.copy
+ - TensorShapeTest.eq_negative_on_unmatched_rank
+ - TensorShapeTest.dim
+ - TensorShapeTest.resize
+ - TensorShapeTest.eq_positive
+ - TensorShapeTest.ctor
+ - TensorFlowLiteImport.Dummy
+ - CircleTransposeConvTest.constructor_P
+ - LOCO.identity_network
+ - CanonicalDialectTest.get
+ - FeatureIndexTest.default_constructor
+ - FeatureIndexTest.settet_and_getter
+ - ADT_FEATURE_LAYOUT.move
+ - ADT_FEATURE_LAYOUT.ctor
+ - ADT_FEATURE_LAYOUT.copy
+ - CircleSoftmaxTest.constructor_P
+ - CanonicalNodeTest.mutable_visitor
+ - CanonicalNodeTest.visitor
+ - CanonicalNodeTest.visitor_with_user_default_impl
+ - NodeExecution_ReLU.f32
+ - ShapeInferenceTest.framework
+ - NodeExecution_EltwiseSqrt.f32
+ - NodeExecution_MatrixCodec.WH_f32
+ - NodeExecution_MatrixCodec.HW_s32
+ - ADT_FEATURE_SHAPE.operator_eq
+ - ADT_FEATURE_SHAPE.ctor
+ - ADT_FEATURE_SHAPE.num_elements
+ - SET.operator_diff
+ - SET.operator_eq
+ - NodeExecution_ConstGen.s32
+ - NodeExecution_ConstGen.f32
+ - CircleMulTest.constructor_P
+ - StrCastTests.safe_strcast_int
+ - NodeExecution_EltwiseMax.f32
+ - NodeExecution_Pull.check_data_ready
+ - FormattedTensorShapeTest.BracketFormat
+ - FilterShapeTest.settet_and_getter
+ - FilterShapeTest.default_constructor
+ - NodeExecution_MaxPool2D.with_padding
+ - NodeExecution_MaxPool2D.f32_1x3x3x1_calculation
+ - NodeExecution_EltwiseAdd.f32
+ - ADT_KERNEL_LAYOUT.move
+ - ADT_KERNEL_LAYOUT.ctor
+ - ADT_KERNEL_LAYOUT.copy
+ - NodeExecution_MatMul.s32_4x2_2x6
+ - NodeExecution_MatMul.f32_2x3_3x3
+ - CircleDepthwiseConv2DTest.constructor_P
+ - NodeExecution_Forward.s32
+ - NodeExecution_Forward.f32
+ - NodeExecution_EltwiseMul.f32
+ - FilterIndexTest.default_constructor
+ - FilterIndexTest.settet_and_getter
+ - DialectTest.service
+ - Session.inference_identity
+ - Session.dtor
+ - Session.session_for_subgraph
+ - Session.set_input
+ - Session.ctor_by_range
+ - Session.graph_IO_size
+ - NodeTest.constructor
+ - NodeTest.replace_with
+ - NodeTest.succs
+ - NodeTest.preds
+ - FixedArityNodeTest.constructor
+
+ negativeTestCase:
+ - condition:
+ - testName:
+ ends:
+ - _NEG
+
+ positiveTestCase:
+ - condition:
+ - inverse: negativeTestCase
printf "21.02" > $(OVERLAY_FOLDER)/ARMCOMPUTE.stamp
endif
+ifneq ($(DEBIAN_BUILD),)
+ test -d externals || mkdir -p externals
+ find packaging/ -type f -name "*.tar.gz" | xargs -i tar xf {} -C externals
+endif
+
NNFW_WORKSPACE="$(WORKSPACE)" NNFW_INSTALL_PREFIX=$(INSTALL_PATH) ./nnfw configure \
-DCMAKE_BUILD_TYPE=$(BUILD_TYPE_LC) \
-DNNFW_OVERLAY_DIR=$(OVERLAY_FOLDER) \
+++ /dev/null
-version: 2
-test:
- - name: NN Compiler
- testCaseLanguage: CPP
- testFW: GTEST
- testCaseFolder:
- - ./angkor
- - ./arser
- - ./circle2circle
- - ./circle-quantizer
- - ./crew
- - ./cwrap
- - ./foder
- - ./hermes
- - ./hermes-std
- - ./loco
- - ./locomotiv
- - ./locop
- - ./logo
- - ./logo-core
- - ./luci
- - ./luci-interpreter
- - ./luci-eval-driver
- - ./luci-pass-value-test
- - ./luci-value-test
- - ./mio-circle
- - ./mio-tflite
- - ./oops
- - ./pepper-assert
- - ./pepper-str
- - ./pepper-strcast
- - ./pp
- - ./record-minmax
- - ./safemain
- - ./souschef
- - ./tflite2circle
-
- testFile:
- - extension: .test.cpp
- any: true
-
- testCase:
- - condition:
- - functionName:
- starts:
- - TEST
-
- negativeTestCase:
- - condition:
- - testName:
- ends:
- - _NEG
-
- positiveTestCase:
- - condition:
- - inverse: negativeTestCase
set(BCQ_TOOLS_FILES
- generate_bcq_metadata
- generate_bcq_output_arrays
generate_bcq_metadata.py
generate_bcq_output_arrays.py
)
+++ /dev/null
-#!/usr/bin/env python3
-
-# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-import tensorflow as tf
-
-import argparse
-import sys
-
-ONE_START_MAGICNUM = int(-2e9 + 27)
-ONE_END_MAGICNUM = int(2e9 - 27)
-
-
-def _get_parser():
- """
- Returns an ArgumentParser for generating BCQ metadata.
- """
- parser = argparse.ArgumentParser(
- description=("Command line tool to generate metadata of BCQ nodes"))
-
- # Input and output path.
- parser.add_argument(
- "-i",
- "--input_path",
- type=str,
- help="Full filepath of the input file.",
- required=True)
- parser.add_argument(
- "-o",
- "--output_path",
- type=str,
- help="Full filepath of the output file.",
- required=True)
- parser.add_argument(
- "-O",
- "--output_arrays",
- type=str,
- help="Original model output arrays",
- required=True)
-
- return parser
-
-
-# This function is copied from
-# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/examples/label_image/label_image.py#L26
-def load_graph(model_file):
- graph = tf.Graph()
- graph_def = tf.compat.v1.GraphDef()
-
- with open(model_file, "rb") as f:
- graph_def.ParseFromString(f.read())
- with graph.as_default():
- tf.import_graph_def(graph_def, name="")
-
- return graph
-
-
-def generate_metadata_header(original_graph, bcq_version, output_arrays):
- # Generating metadata starts
- metadata_values = np.array([ONE_START_MAGICNUM])
-
- # Append BCQ version
- metadata_values = np.append(metadata_values, bcq_version)
-
- # Append original output count
- output_cnt = output_arrays.count(',') + 1
- metadata_values = np.append(metadata_values, output_cnt)
-
- return metadata_values
-
-
-def generate_bcq_metadata_v1(flags):
- """
- BCQv1 contains following metadata.
- - The number of each BCQ information set
- """
-
- is_valid = True
- allowed_info_names = [
- "bcqinfo_do_w_x", "bcqinfo_alpha", "bcqinfo_packed_binary_code",
- "bcqinfo_number_of_clusters", "bcqinfo_size_of_clusters",
- "bcqinfo_qbits_of_clusters", "bcqinfo_dequant_weight"
- ]
-
- original_graph = load_graph(flags.input_path)
- original_graph_def = original_graph.as_graph_def()
-
- prefix_infonames_dict = {}
-
- for node in original_graph_def.node:
- if node.op == "Const" and "/bcqinfo_" in node.name:
- prefix_index = node.name.index("/bcqinfo_")
- prefix = node.name[:prefix_index]
- infoname = node.name[prefix_index + 1:]
-
- if infoname not in allowed_info_names:
- is_valid = False
- break
-
- if prefix not in prefix_infonames_dict:
- prefix_infonames_dict[prefix] = set()
-
- prefix_infonames_dict[prefix].add(infoname)
-
- # All the number of BCQ information should be same
- num_of_bcqinfo = -1
- for key in prefix_infonames_dict:
- infonames = prefix_infonames_dict[key]
- if num_of_bcqinfo == -1:
- num_of_bcqinfo = len(infonames)
- elif num_of_bcqinfo != len(infonames):
- is_valid = False
-
- # The number of BCQv1 information should be 6 or 7
- if num_of_bcqinfo != 6 and num_of_bcqinfo != 7:
- is_valid = False
-
- # If BCQ information is invalid, return original model
- if is_valid == False:
- return original_graph_def
-
- new_graph_def = tf.compat.v1.GraphDef()
- for node in original_graph_def.node:
- new_node = new_graph_def.node.add()
- new_node.CopyFrom(node)
-
- # Generate metadata header
- metadata_values = generate_metadata_header(original_graph, 1, flags.output_arrays)
-
- # Append metadata of BCQv1
- metadata_values = np.append(metadata_values, num_of_bcqinfo + 1)
-
- # Finish generating metadata
- metadata_values = np.append(metadata_values, ONE_END_MAGICNUM)
-
- # Generate metadata tensor
- metadata_tensor = tf.make_tensor_proto(metadata_values, tf.int32)
-
- new_node = new_graph_def.node.add()
- new_node.op = "Const"
- new_node.name = "one_compiler/bcqinfo_one_metadata"
- new_node.attr["dtype"].CopyFrom(
- tf.core.framework.attr_value_pb2.AttrValue(type=tf.int32.as_datatype_enum))
- new_node.attr["value"].tensor.CopyFrom(metadata_tensor)
- return new_graph_def
-
-
-def determine_bcq_version(flags):
- """
- CAUTION : For now, BCQ has only one version and thus always returns 1 when BCQ
- information nodes are included. If new BCQ version is introduced,
- this function must be updated accordingly.
-
- When BCQ information does not exist, -1 is returned.
- """
- bcq_version = -1
-
- original_graph = load_graph(flags.input_path)
- original_graph_def = original_graph.as_graph_def()
-
- for node in original_graph_def.node:
- if node.op == "Const" and "/bcqinfo_" in node.name:
- bcq_version = 1
- break
-
- return bcq_version
-
-
-def generate_bcq_metadata(flags):
- """
- Basic format of metadata is as following.
- - Magic number indicating start
- - Version of BCQ Format
- - The number of original outputs
- - Metadata based on each BCQ format
- - Magic number indicating end
- """
- program_version = 1
- model_version = determine_bcq_version(flags)
-
- if model_version == 1:
- result_graph_def = generate_bcq_metadata_v1(flags)
- elif model_version == -1:
- # When there is no BCQ information, do nothing
- result_graph_def = load_graph(flags.input_path)
- else:
- err_msg = "BCQ version of the model(v{}) ".format(model_version)
- err_msg += "is higher than "
- err_msg += "the version supported by this program(v{})".format(program_version)
- raise SystemExit(err_msg)
-
- tf.io.write_graph(result_graph_def, '.', flags.output_path, False)
-
-
-def main():
- # Parse argument.
- parser = _get_parser()
- flags = parser.parse_known_args(args=sys.argv[1:])
-
- # Generate a new pb file, which BCQ metadata is included.
- generate_bcq_metadata(flags[0])
-
-
-if __name__ == "__main__":
- main()
import tensorflow as tf
import argparse
+import os
import sys
+# TODO Find a better way to suppress the traceback on error
+sys.tracebacklimit = 0
+
ONE_START_MAGICNUM = int(-2e9 + 27)
ONE_END_MAGICNUM = int(2e9 - 27)
new_node.op = "Const"
new_node.name = "one_compiler/bcqinfo_one_metadata"
new_node.attr["dtype"].CopyFrom(
- tf.core.framework.attr_value_pb2.AttrValue(type=tf.int32.as_datatype_enum))
+ tf.compat.v1.AttrValue(type=tf.int32.as_datatype_enum))
new_node.attr["value"].tensor.CopyFrom(metadata_tensor)
return new_graph_def
if __name__ == "__main__":
- main()
+ try:
+ main()
+ except Exception as e:
+ prog_name = os.path.basename(__file__)
+ print(f"{prog_name}: {type(e).__name__}: " + str(e))
+ sys.exit(255)
+++ /dev/null
-#!/usr/bin/env python3
-
-# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import tensorflow as tf
-
-import argparse
-import sys
-
-
-def _get_parser():
- """
- Returns an ArgumentParser for generating output_arrays.
- """
- parser = argparse.ArgumentParser(
- description=("Command line tool to generated output_arrays of BCQ nodes"))
-
- # Input and output path.
- parser.add_argument(
- "-i",
- "--input_path",
- type=str,
- help="Full filepath of the input file.",
- required=True)
- parser.add_argument(
- "-m",
- "--metadata_path",
- type=str,
- help="Full filepath for the file that provides metadata.",
- required=True)
- parser.add_argument(
- "-A",
- "--output_arrays_path",
- type=str,
- help="Full filepath for the file that provides output arrays",
- required=True)
-
- return parser
-
-
-# This function is copied from
-# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/examples/label_image/label_image.py#L26
-def load_graph(model_file):
- graph = tf.Graph()
- graph_def = tf.compat.v1.GraphDef()
-
- with open(model_file, "rb") as f:
- graph_def.ParseFromString(f.read())
- with graph.as_default():
- tf.import_graph_def(graph_def, name="")
-
- return graph
-
-
-def find_bcq_version(flags):
- """
- If BCQ metadata exists, BCQ version is in the second element.
- Return -1 when the metadata is not found.
- """
- graph = load_graph(flags.input_path)
- graph_def = graph.as_graph_def()
- for node in graph_def.node:
- if node.op == "Const" and "one_compiler/bcqinfo_one_metadata" in node.name:
- metadata_tensor = tf.make_ndarray(node.attr["value"].tensor)
- return metadata_tensor[1]
- return -1
-
-
-def print_bcqinfo_output_arrays_v1(flags):
- """
- This function generates a file which includes output arrays of BCQ v1
- information bundles. Each bundle is consisted with one of candidate
- operations (BCQ may be applied) and BCQ constant nodes related with
- the operation.
- """
- graph = load_graph(flags.input_path)
- graph_def = graph.as_graph_def()
- ops = graph.get_operations()
-
- # If there is a constant node named PREFIX_1/bcqinfo_alpha,
- # it is used for applying BCQ to constant node named PREFIX_1.
- # Collected prefixes will be used for connecting
- # bcqinfo nodes and user operations of prefix nodes.
- prefix_set = set()
- has_dequant_weight = False
- for op in ops:
- if op.type == "Const" and "/bcqinfo_" in op.outputs[0].name:
- # Metadata do not have prefix
- if "one_compiler/bcqinfo_one_metadata" in op.outputs[0].name:
- continue
-
- prefix_index = op.outputs[0].name.index("/bcqinfo_")
- prefix = op.outputs[0].name[:prefix_index]
- prefix_set.add(prefix)
-
- # Usually, output name of op is like "outputname:0"
- # -2 is for removing ":0"
- infoname = op.outputs[0].name[prefix_index + 1:-2]
- if infoname == "bcqinfo_dequant_weight":
- has_dequant_weight = True
-
- # Write the name of metadata node
- with open(flags.metadata_path, 'w') as f_metadata:
- f_metadata.write("one_compiler/bcqinfo_one_metadata,")
-
- # Write all pairs of a constant node and related BCQ information nodes.
- with open(flags.output_arrays_path, 'w') as f_arrays:
- for prefix in prefix_set:
- f_arrays.write("," + prefix + "/bcqinfo_do_w_x")
- f_arrays.write("," + prefix + "/bcqinfo_alpha")
- f_arrays.write("," + prefix + "/bcqinfo_packed_binary_code")
- f_arrays.write("," + prefix + "/bcqinfo_number_of_clusters")
- f_arrays.write("," + prefix + "/bcqinfo_size_of_clusters")
- f_arrays.write("," + prefix + "/bcqinfo_qbits_of_clusters")
- f_arrays.write("," + prefix)
- if has_dequant_weight:
- f_arrays.write("," + prefix + "/bcqinfo_dequant_weight")
-
-
-def print_bcq_output_arrays(flags):
- program_version = 1
- model_version = find_bcq_version(flags)
-
- if model_version == 1:
- print_bcqinfo_output_arrays_v1(flags)
- elif model_version == -1:
- # When BCQ information not found, print nothing.
- f_metadata = open(flags.metadata_path, 'w')
- f_arrays = open(flags.output_arrays_path, 'w')
- f_metadata.close()
- f_arrays.close()
- else:
- err_msg = "BCQ version of the model(v{}) ".format(model_version)
- err_msg += "is higher than "
- err_msg += "the version supported by this program(v{})".format(program_version)
- raise SystemExit(err_msg)
-
-
-def main():
- # Parse argument.
- parser = _get_parser()
- flags = parser.parse_known_args(args=sys.argv[1:])
-
- print_bcq_output_arrays(flags[0])
-
-
-if __name__ == "__main__":
- main()
luci_interpreter::Interpreter interpreter(module.get());
// Set input
- // TODO support multiple subgraphs
- assert(module->size() == 1);
const auto input_nodes = loco::input_nodes(module->graph());
int32_t num_inputs = static_cast<int32_t>(input_nodes.size());
for (int32_t i = 0; i < num_inputs; i++)
void PModelsRunner::save_outputs(const std::string &output_file)
{
+ LOGGER(l);
+
// load source model as we need to get both shape and node name
// TODO check for unknown shape
auto source_fname = _pconfig.source.model_file;
+ INFO(l) << "save_outputs() loading file: " << source_fname << std::endl;
auto module = import_circle(source_fname);
const auto output_nodes = loco::output_nodes(module->graph());
const auto *output_node = loco::must_cast<const luci::CircleOutput *>(output_nodes[i]);
auto output_name = output_node->name();
+ INFO(l) << "save_outputs() save output node: " << output_name << std::endl;
assert(_data_stage.find(output_name) != _data_stage.end());
auto tensor_data = _data_stage[output_name];
unset(RECIPE_LIST)
unset(PARTITION_LIST)
+unset(OUTPUT_COUNT_LIST)
unset(TEST_DEPS)
-macro(add RECIPE_NAME PARTITION_NAME)
+macro(add RECIPE_NAME PARTITION_NAME OUTPUT_COUNT)
list(APPEND RECIPE_LIST ${RECIPE_NAME})
list(APPEND PARTITION_LIST ${PARTITION_NAME})
+ list(APPEND OUTPUT_COUNT_LIST ${OUTPUT_COUNT})
endmacro(add)
# Read "test.lst"
foreach(IDX RANGE ${RECIPE_LENGTH_M1})
list(GET RECIPE_LIST ${IDX} RECIPE_NAME)
list(GET PARTITION_LIST ${IDX} PARTITION_NAME)
+ list(GET OUTPUT_COUNT_LIST ${IDX} OUTPUT_COUNT)
# NOTE about the name:
# Use '.recipe' name for source tflite and circle files
add_custom_command(OUTPUT ${TFLITE_DST_PATH}
COMMAND ${CMAKE_COMMAND} -E copy "${TFLITE_SRC_PATH}" "${TFLITE_DST_PATH}"
- DEPENDS ${TFLITE_SRC_PATH}
+ DEPENDS ${TFLITE_SRC_PATH} ${PARTITIONER_OUTPUT_PATH}
COMMENT "Copy ${RECIPE_NAME}.tflite"
)
list(APPEND TEST_DEPS ${TFLITE_DST_PATH})
add_custom_command(OUTPUT ${CIRCLE_DST_PATH}
COMMAND ${CMAKE_COMMAND} -E copy "${CIRCLE_SRC_PATH}" "${CIRCLE_DST_PATH}"
- DEPENDS ${CIRCLE_SRC_PATH}
+ DEPENDS ${CIRCLE_SRC_PATH} ${PARTITIONER_OUTPUT_PATH}
COMMENT "Copy ${RECIPE_NAME}.circle"
)
list(APPEND TEST_DEPS ${CIRCLE_DST_PATH})
add_custom_command(OUTPUT ${PART_DST_PATH}
COMMAND ${CMAKE_COMMAND} -E copy "${PART_SRC_PATH}" "${PART_DST_PATH}"
- DEPENDS ${PART_SRC_PATH}
+ DEPENDS ${PART_SRC_PATH} ${PARTITIONER_OUTPUT_PATH}
COMMENT "Copy ${PART_FILE}"
)
list(APPEND TEST_DEPS ${PART_DST_PATH})
COMMENT "Parition ${RECIPE_NAME}.circle with ${PART_FILE}"
)
list(APPEND TEST_DEPS ${PARTITIONER_CONN_JSON})
+
+ # Write .excnt file; expected count of output models
+ set(COUNT_FILE "${PARTITION_NAME}.excnt")
+ set(COUNT_FILE_PATH "${PARTITIONER_OUTPUT_PATH}/${COUNT_FILE}")
+ add_custom_command(OUTPUT ${COUNT_FILE_PATH}
+ COMMAND echo ${OUTPUT_COUNT} > ${COUNT_FILE_PATH}
+ DEPENDS ${PART_SRC_PATH} ${PARTITIONER_OUTPUT_PATH}
+ COMMENT "Write ${COUNT_FILE} with ${OUTPUT_COUNT}"
+ )
+ list(APPEND TEST_DEPS ${COUNT_FILE_PATH})
endforeach(IDX)
add_custom_target(circle_part_value_test_prepare ALL DEPENDS ${TEST_DEPS})
import subprocess
import argparse
import traceback
+import json
#
# This script compares the execution result of TFLite interpreter and
tflite_model = args.name + ".tflite"
circle_model = args.name + ".circle"
partition_conn_ini = args.name + ".conn.ini"
+partition_conn_json = args.name + ".conn.json"
+expected_count = args.name + ".excnt"
+
+# Check expected count of models from partitioning
+try:
+ with open(expected_count, "r") as expected_count_file:
+ expected_count_line = expected_count_file.readline()
+
+ expected_count_line = int(expected_count_line)
+ if expected_count_line:
+ with open(partition_conn_json) as json_file:
+ json_data = json.load(json_file)
+ parts_value = json_data["parts"]
+ if len(parts_value) != expected_count_line:
+ print("Partitioned model count differs from expected:",
+ expected_count_line)
+ quit(255)
+
+ print("Partitioned model count expected: ", expected_count_line)
+ else:
+ print("Skip expected partitioned model count check: 0")
+
+except Exception:  # skip the check on missing/invalid files; do not swallow SystemExit from quit(255)
+ print("Skip expected partitioned model count check: error")
# Build TFLite interpreter.
interpreter = tf.lite.Interpreter(tflite_model)
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opname
+
+[OPNAME]
+Mean_as_variance=acl_cl
+Add_as_variance=acl_cl
+Pow=acl_cl
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opname
+
+[OPNAME]
+add1=acl_cl
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opname
+
+[OPNAME]
+some/node/add2;and/another=acl_cl
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opname
+
+[OPNAME]
+add1=cpu
+add2=acl_cl
+ofm=acl_cl
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opname
+
+[OPNAME]
+add1=acl_cl
+add2=acl_cl
+ofm=cpu
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+ADD=acl_cl
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+ADD=acl_cl
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+MAXIMUM=acl_cl
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+MAXIMUM=acl_cl
# Add recipe names from /res/TensorFlowLiteRecipes to test.
# Only add items that exist in the common-artifacts test: tflite/circle files are copied as source.
#
-# add(RECIPE_NAME PARTITION_NAME)
+# add(RECIPE_NAME PARTITION_NAME EXPECTED_OUTPUT_COUNT)
+# EXPECTED_OUTPUT_COUNT: 0 to skip the expected count test
-add(Part_Add_Sub_000 Part_Add_Sub_000)
-add(Part_Sqrt_Rsqrt_000 Part_Sqrt_Rsqrt_000)
-add(Part_Sqrt_Rsqrt_001 Part_Sqrt_Rsqrt_001)
-add(Part_Sqrt_Rsqrt_002 Part_Sqrt_Rsqrt_002)
-add(Part_Sqrt_Rsqrt_003 Part_Sqrt_Rsqrt_003)
-add(Part_Sqrt_Rsqrt_Add_000 Part_Sqrt_Rsqrt_Add_000)
-add(Part_Sqrt_Rsqrt_Add_001 Part_Sqrt_Rsqrt_Add_001)
-add(Part_Sqrt_Rsqrt_Add_002 Part_Sqrt_Rsqrt_Add_002)
-add(Part_Sqrt_Rsqrt_Add_003 Part_Sqrt_Rsqrt_Add_003)
-add(Part_Sqrt_Rsqrt_Add_004 Part_Sqrt_Rsqrt_Add_004)
-add(Part_Add_Sqrt_000 Part_Add_Sqrt_000)
-add(Part_Add_Sqrt_Rsqrt_000 Part_Add_Sqrt_Rsqrt_000)
-add(Net_InstanceNorm_003 Net_InstanceNorm_003)
-add(Net_InstanceNorm_003 Net_InstanceNorm_003.001)
-add(Net_InstanceNorm_003 Net_InstanceNorm_003.002)
+add(Part_Add_Sub_000 Part_Add_Sub_000 2)
+add(Part_Sqrt_Rsqrt_000 Part_Sqrt_Rsqrt_000 2)
+add(Part_Sqrt_Rsqrt_001 Part_Sqrt_Rsqrt_001 2)
+add(Part_Sqrt_Rsqrt_002 Part_Sqrt_Rsqrt_002 4)
+add(Part_Sqrt_Rsqrt_003 Part_Sqrt_Rsqrt_003 3)
+add(Part_Sqrt_Rsqrt_Add_000 Part_Sqrt_Rsqrt_Add_000 3)
+add(Part_Sqrt_Rsqrt_Add_001 Part_Sqrt_Rsqrt_Add_001 3)
+add(Part_Sqrt_Rsqrt_Add_002 Part_Sqrt_Rsqrt_Add_002 4)
+add(Part_Sqrt_Rsqrt_Add_003 Part_Sqrt_Rsqrt_Add_003 1)
+add(Part_Sqrt_Rsqrt_Add_004 Part_Sqrt_Rsqrt_Add_004 1)
+add(Part_Add_Sqrt_000 Part_Add_Sqrt_000 3)
+add(Part_Add_Sqrt_Rsqrt_000 Part_Add_Sqrt_Rsqrt_000 3)
+add(Net_InstanceNorm_003 Net_InstanceNorm_003 3)
+add(Net_InstanceNorm_003 Net_InstanceNorm_003.001 5)
+# skip expected count for now
+add(Net_InstanceNorm_003 Net_InstanceNorm_003.002 0)
+
+# comply=opname
+add(Part_Add_Sub_000 Part_Add_Sub_000.001 3)
+add(Part_Add_Sub_001 Part_Add_Sub_001 3)
+add(Part_Add_Sub_002 Part_Add_Sub_002.001 2)
+add(Part_Add_Sub_002 Part_Add_Sub_002.002 2)
+add(Net_InstanceNorm_003 Net_InstanceNorm_003.003 3)
+
+# IF with subgraphs
+add(Part_If_Add_Sub_000 Part_If_Add_Sub_000.001 3)
+add(Part_If_Add_Sub_001 Part_If_Add_Sub_001.001 3)
+
+# WHILE with subgraphs
+add(Part_While_000 Part_While_000 3)
+add(Part_While_001 Part_While_001 3)
--- /dev/null
+# NOTE Tests below check that circle-partitioner performs the partitioning itself.
+# Once this test passes, add the partition to 'circle-part-value-test' for a
+# full test.
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+
+unset(RECIPE_LIST)
+unset(PART_LIST)
+unset(TEST_DEPS)
+
+macro(add RECIPE_NAME PART_NAME)
+ list(APPEND RECIPE_LIST ${RECIPE_NAME})
+ list(APPEND PART_LIST ${PART_NAME})
+endmacro(add)
+
+# Read "test.lst"
+include("test.lst")
+
+list(LENGTH RECIPE_LIST RECIPE_LENGTH)
+math(EXPR RECIPE_LENGTH_M1 "${RECIPE_LENGTH} - 1")
+
+foreach(IDX RANGE ${RECIPE_LENGTH_M1})
+ list(GET RECIPE_LIST ${IDX} RECIPE_NAME)
+ list(GET PART_LIST ${IDX} PART_NAME)
+
+ set(PART_OUT_PATH "${CMAKE_CURRENT_BINARY_DIR}/${PART_NAME}")
+
+ add_custom_command(OUTPUT ${PART_OUT_PATH}
+ COMMAND ${CMAKE_COMMAND} -E make_directory "${PART_OUT_PATH}"
+ COMMENT "Make directory ${PART_OUT_PATH}"
+ )
+
+ set(CIRCLE_SRC_PATH "${ARTIFACTS_BIN_PATH}/${RECIPE_NAME}.circle")
+ set(CIRCLE_DST_PATH "${PART_OUT_PATH}/${PART_NAME}.circle")
+
+ # Copy circle
+ add_custom_command(OUTPUT ${CIRCLE_DST_PATH}
+ COMMAND ${CMAKE_COMMAND} -E copy "${CIRCLE_SRC_PATH}" "${CIRCLE_DST_PATH}"
+ DEPENDS ${CIRCLE_SRC_PATH}
+ COMMENT "Copy ${RECIPE_NAME}.circle"
+ )
+
+ set(PART_FILE "${PART_NAME}.part")
+ set(PART_SRC_PATH "${CMAKE_CURRENT_SOURCE_DIR}/parts/${PART_FILE}")
+ set(PART_DST_PATH "${PART_OUT_PATH}/${PART_FILE}")
+
+ # Copy .part
+ add_custom_command(OUTPUT ${PART_DST_PATH}
+ COMMAND ${CMAKE_COMMAND} -E copy "${PART_SRC_PATH}" "${PART_DST_PATH}"
+ DEPENDS ${PART_SRC_PATH}
+ COMMENT "Copy ${PART_FILE}"
+ )
+
+ # Run partitioner
+ set(PART_CONN_JSON "${PART_OUT_PATH}/${PART_NAME}.conn.json")
+ add_custom_command(OUTPUT ${PART_CONN_JSON}
+ COMMAND circle_partitioner "${PART_FILE}" "${PART_NAME}.circle" "${PART_OUT_PATH}"
+ DEPENDS circle_partitioner ${CIRCLE_DST_PATH} ${PART_DST_PATH}
+    COMMENT "Partition ${RECIPE_NAME}.circle with ${PART_FILE}"
+ )
+  # NOTE this is checked at build time and not added with the 'add_test' command
+  # to reduce the scripts needed to run tests. Actual testing is done in 'circle-part-value-test'.
+
+ list(APPEND TEST_DEPS ${CIRCLE_DST_PATH} ${PART_DST_PATH} ${PART_CONN_JSON})
+endforeach(IDX)
+
+add_custom_target(circle_partitioner_test ALL DEPENDS ${TEST_DEPS})
+add_dependencies(circle_partitioner_test common_artifacts_deps)
--- /dev/null
+# circle-partitioner-test
+
+_circle-partitioner-test_ provides tests for _circle-partitioner_;
+it checks that partitioning works correctly, without value testing.
+- Full value testing is done with _circle-part-value-test_.
+
+The purpose of this test is to check locally how the partitioning itself is done
+before value testing. Once you have confirmed that model partitioning works as
+you expect, you can add the test to _circle-part-value-test_.
+
+It is not necessary to commit tests for this module upstream.
--- /dev/null
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+DIV=acl_cl
--- /dev/null
+require("circle-partitioner")
+require("common-artifacts")
--- /dev/null
+# Add recipes in /res/TensorFlowLiteRecipes to test.
+# NOTE: only add items that exist in the common-artifacts test: circle files are copied
+# from common-artifacts.
+# Use this list before the end-to-end test in 'circle-part-value-test'.
+# add(RECIPE_NAME PART_NAME)
+
+add(Net_InstanceNorm_003 Net_InstanceNorm_003)
target_link_libraries(circle_partitioner luci_log)
target_link_libraries(circle_partitioner luci_import)
target_link_libraries(circle_partitioner luci_service)
+target_link_libraries(circle_partitioner luci_pass)
target_link_libraries(circle_partitioner luci_export)
target_link_libraries(circle_partitioner luci_partition)
target_link_libraries(circle_partitioner arser)
+target_link_libraries(circle_partitioner pepper_csv2vec)
target_link_libraries(circle_partitioner vconone)
target_link_libraries(circle_partitioner nncc_common)
# circle-partitioner
_circle-partitioner_ provides partitioning of a circle model into two or more circle models.
+
+## How circle-partitioner works
+
+_circle-partitioner_ requires three positional arguments:
+- first: `partition` file
+- second: `input` circle model file
+- third: `work` folder
+
+It also accepts options that override values in the `partition` file, as a helper to try settings out without editing the file (see the example below):
+- `--backends`: override `backends` of `[partition]` section
+- `--default`: override `default` of `[partition]` section
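+
+For example, a quick trial that overrides both values on the command line could look like
+this (a sketch; the file names are taken from the example further below):
+
+```
+./circle_partitioner --backends cpu,acl_cl --default cpu \
+  Net_InstanceNorm_003.part Net_InstanceNorm_003.circle Net_InstanceNorm_003
+```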
+
+_circle-partitioner_ reads the `partition` and `input` files from the `work` folder, groups
+nodes that share the same backend, and stores each group as a new circle model in the `work` folder.
+
+The outputs are (1) one or more partitioned circle models and (2) a connection file that describes
+how the partitioned models should be connected so that they behave like the source `input` model.
+
+Why are the input files placed in the `work` folder too?
+- this is still a work in progress
+- use cases are still ambiguous
+- the original `input` model file can be used by the backend, so the `.conn` file links it as `source`
+- to keep things simple for the backend, it is better not to use relative paths for the files
+
+### `partition` file
+
+The `partition` file follows the INI format of the _crew_ project.
+
+Several example files exist in the _circle-part-value-test_ `parts` folder.
+
+This section explains the format using the `Net_InstanceNorm_003.part` file as an example.
+```ini
+[partition]
+backends=cpu,acl_cl
+default=cpu
+comply=opcode
+
+[OPCODE]
+DIV=acl_cl
+```
+
+##### `[partition]` section
+
+The `[partition]` section is the main section and is read first.
+- `backends`: partition group names into which nodes are placed, in CSV format.
+- `default`: the default group name, which should be one of the `backends` items.
+- `comply`: how to group the nodes of the model.
+ - currently `opcode` is supported
+ - future work: set group by node name or sequence number.
+
+##### `[OPCODE]` section
+
+This section specifies how nodes are grouped by OPCODE.
+Nodes with a listed OPCODE are assigned to the given group.
+This does not mean that the number of output circle files equals the number of backends;
+the number of output circle files also depends on the network structure.
+
+For the example above, all `DIV` nodes are assigned to the `acl_cl` backend.
+
+The `[OPCODE]` section can override the `default` backend set in the `[partition]` section by using the `_` key.
+
+For example, the default can be changed to `cpu`:
+```
+[OPCODE]
+_=cpu
+DIV=acl_cl
+```
+
+### `circle` file
+
+Just a normal `circle` file. Partitioning currently has limited support; models with the
+following properties are not supported yet:
+- models with multiple subgraphs
+- operators with multiple outputs, such as IF or WHILE
+
+### `work` folder
+
+The `partition` and `circle` files should reside in the `work` folder. Output files are
+generated inside this folder.
+
+### Example
+
+Typical source of partitioning:
+```
+$ tree Net_InstanceNorm_003/
+Net_InstanceNorm_003/
+├── Net_InstanceNorm_003.circle
+└── Net_InstanceNorm_003.part
+```
+
+Command example
+```
+./circle_partitioner Net_InstanceNorm_003.part Net_InstanceNorm_003.circle Net_InstanceNorm_003
+```
+
+Result of _circle-partitioner_
+```
+$ tree Net_InstanceNorm_003/
+Net_InstanceNorm_003/
+├── Net_InstanceNorm_003.00001_cpu.circle
+├── Net_InstanceNorm_003.00002_acl_cl.circle
+├── Net_InstanceNorm_003.00003_cpu.circle
+├── Net_InstanceNorm_003.circle
+├── Net_InstanceNorm_003.conn.ini
+├── Net_InstanceNorm_003.conn.json
+└── Net_InstanceNorm_003.part
+```
+
+### `Net_InstanceNorm_003.conn.ini` and `Net_InstanceNorm_003.conn.json`
+
+These two files are identical in content but in different formats.
+
+The `.conn` file provides information on how to reconstruct the partitioned models,
+`Net_InstanceNorm_003.00001_cpu.circle`, `Net_InstanceNorm_003.00002_acl_cl.circle`
+and `Net_InstanceNorm_003.00003_cpu.circle`, so that the result is identical to the
+source `Net_InstanceNorm_003.circle` model in computational results.
+
+Here, `reconstruct` means connecting the outputs and inputs of the partitioned
+models.
+
+```json
+$ cat Net_InstanceNorm_003/Net_InstanceNorm_003.conn.json
+{
+ "source" : {
+ "file" : "Net_InstanceNorm_003.circle",
+ "inputs" : [ "Input" ],
+ "outputs" : [ "Add_as_terminal" ]
+ },
+ "parts" : [
+ {
+ "file" : "Net_InstanceNorm_003.00001_cpu.circle",
+ "inputs" : [ "Input" ],
+ "outputs" : [ "Pow", "Sub" ]
+ },
+ {
+ "file" : "Net_InstanceNorm_003.00002_acl_cl.circle",
+ "inputs" : [ "Sub", "Pow" ],
+ "outputs" : [ "Div" ]
+ },
+ {
+ "file" : "Net_InstanceNorm_003.00003_cpu.circle",
+ "inputs" : [ "Div" ],
+ "outputs" : [ "Add_as_terminal" ]
+ }
+ ]
+}
+```
+The file above is in `JSON` format, with `source` describing the source model and `parts`
+describing the partitioned models. Each entry in `parts` has `file` for the model file,
+`inputs` for its input nodes, and `outputs` for its output nodes.
+
+From `source` we can identify the inputs and outputs of the original model. The items in
+`parts` connect as follows (a small check sketch follows the list).
+
+- Each item in `outputs` should connect to the `inputs` of another `parts` item,
+or should be one of the `outputs` of the `source` model.
+- For the first model, `Net_InstanceNorm_003.00001_cpu.circle`, `inputs` is the same
+as for the `source` model: `[ "Input" ]`.
+- Its `outputs` `[ "Pow", "Sub" ]` appear under the same names in the `inputs` of the second
+model, `Net_InstanceNorm_003.00002_acl_cl.circle`, to which they should be connected.
+- And the `outputs` `[ "Div" ]` should be connected to the `inputs` of the
+third model, `Net_InstanceNorm_003.00003_cpu.circle`.
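+
+As an illustration only (not part of _circle-partitioner_), a minimal Python sketch that checks
+these connection rules against a `.conn.json` file could look like the following; the script name
+and its invocation are hypothetical:
+
+```python
+# check_conn.py -- hypothetical helper, not shipped with circle-partitioner
+import json
+import sys
+
+with open(sys.argv[1]) as f:  # e.g. Net_InstanceNorm_003.conn.json
+    conn = json.load(f)
+
+# Collect every input name used by any partitioned model.
+part_inputs = set()
+for part in conn["parts"]:
+    part_inputs.update(part["inputs"])
+
+source_outputs = set(conn["source"]["outputs"])
+
+# Each output of a part must feed another part or be an output of the source model.
+for part in conn["parts"]:
+    for name in part["outputs"]:
+        if name not in part_inputs and name not in source_outputs:
+            raise SystemExit(f"unconnected output: {name} in {part['file']}")
+
+print("connection check: OK")
+```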
require("foder")
require("crew")
+require("pepper-csv2vec")
require("safemain")
require("luci")
require("arser")
#include "PartitionRead.h"
#include "PartitionExport.h"
#include "HelperPath.h"
-#include "HelperStrings.h"
#include <foder/FileLoader.h>
#include <luci/Service/Validate.h>
#include <luci/CircleExporter.h>
#include <luci/CircleFileExpContract.h>
+#include <luci/CircleOptimizer.h>
+#include <luci/PartitionDump.h>
+#include <luci/PartitionValidate.h>
#include <luci/Log.h>
+#include <pepper/csv2vec.h>
#include <arser/arser.h>
#include <vconone/vconone.h>
return importer.importModule(circle_model);
}
-bool validate_module(luci::Module *module)
-{
- for (size_t g = 0; g < module->size(); ++g)
- {
- auto graph = module->graph(g);
- if (!luci::validate(graph))
- {
- std::cerr << "ERROR: Invalid circle model" << std::endl;
- return false;
- }
- if (!luci::validate_name(graph))
- {
- std::cerr << "ERROR: circle model has empty name" << std::endl;
- return false;
- }
- }
-
- if (!luci::validate_unique_name(module))
- {
- std::cerr << "ERROR: circle model has duplicate names" << std::endl;
- return false;
- }
-
- return true;
-}
-
-bool validate_partition(luci::PartitionTable &partition)
-{
- if (partition.groups.size() == 0)
- {
- std::cerr << "There is no 'backends' information";
- return false;
- }
- if (partition.default_group.empty())
- {
- std::cerr << "There is no 'default' backend information";
- return false;
- }
- if (!partee::is_one_of(partition.default_group, partition.groups))
- {
- std::cerr << "'default' backend is not one of 'backends' item";
- return false;
- }
- for (auto &byopcode : partition.byopcodes)
- {
- if (!partee::is_one_of(byopcode.second, partition.groups))
- {
- std::cerr << "OPCODE " << byopcode.first << " is not assigned to one of 'backends' items";
- return false;
- }
- }
- return true;
-}
-
-void dump(std::ostream &os, const luci::PartitionTable &table)
-{
- os << "Backends:";
- for (auto &group : table.groups)
- {
- os << " " << group;
- if (table.default_group == group)
- os << "(default)";
- }
- os << std::endl;
-
- os << "Assign by OPCODE: " << std::endl;
- for (auto &item : table.byopcodes)
- os << " " << item.first << "=" << item.second << std::endl;
-}
-
-std::ostream &operator<<(std::ostream &os, const luci::PartitionTable &table)
-{
- dump(os, table);
- return os;
-}
-
} // namespace
int entry(int argc, char **argv)
{
return EXIT_FAILURE;
}
- if (!validate_module(module.get()))
+ // Run default shape/dtype inference before validation
+ // NOTE CircleWhileOut default shape is INVALID as it needs initial shape
+  // inference. This is because WHILE may have a dynamic shape.
+ luci::CircleOptimizer optimizer;
+  (void)optimizer.options(); // need to call this once to create the internal options member
+ for (size_t g = 0; g < module->size(); ++g)
+ {
+ auto graph = module->graph(g);
+ optimizer.optimize(graph);
+ }
+ if (!luci::validate(module.get()))
{
return EXIT_FAILURE;
}
if (arser[opt_bks])
{
auto backend_backends = arser.get<std::string>(opt_bks);
- partition.groups = partee::csv_to_vector<std::string>(backend_backends);
+ partition.groups = pepper::csv_to_vector<std::string>(backend_backends);
}
if (arser[opt_def])
{
partition.default_group = arser.get<std::string>(opt_def);
}
}
- if (!validate_partition(partition))
+ if (!luci::validate(partition))
{
+ // NOTE error reason/message is put to std::cerr inside validate()
return EXIT_FAILURE;
}
*/
#include "PartitionRead.h"
-#include "HelperStrings.h"
#include <crew/PConfigIni.h>
#include <crew/PConfigIniDump.h>
#include <luci/Log.h>
+#include <pepper/csv2vec.h>
#include <stdexcept>
const char *_section_partition = "partition";
const char *_section_OPCODE = "OPCODE";
+const char *_section_OPNAME = "OPNAME";
+
+const char *_comply_opcode = "opcode";
+const char *_comply_opname = "opname";
const char *_key_backends = "backends";
const char *_key_default = "default";
+const char *_key_comply = "comply";
const char *_key_underscore = "_";
luci::PartitionTable parse_table(const crew::Sections §ions)
{
luci::PartitionTable table;
+ // default comply as OPCODE
+ table.comply = luci::PartitionTable::COMPLY::OPCODE;
+
+ // read main '[partition]' first
for (auto §ion : sections)
{
if (section.name == _section_partition)
throw std::invalid_argument("'default' is required");
}
- table.groups = csv_to_vector<std::string>(items.at(_key_backends));
+ table.groups = pepper::csv_to_vector<std::string>(items.at(_key_backends));
table.default_group = items.at(_key_default);
+
+ auto comply = items.at(_key_comply);
+
+ // check valid comply types
+ if (comply == _comply_opcode)
+ {
+ table.comply = luci::PartitionTable::COMPLY::OPCODE;
+ continue;
+ }
+ if (comply == _comply_opname)
+ {
+ table.comply = luci::PartitionTable::COMPLY::OPNAME;
+ continue;
+ }
+      throw std::runtime_error("Invalid 'comply' value");
}
- else if (section.name == _section_OPCODE)
+ }
+
+ // read other sections
+ for (auto §ion : sections)
+ {
+ if (section.name == _section_OPCODE)
{
auto &items = section.items;
for (auto &item : items)
{
if (item.first == _key_underscore)
- table.default_group = item.second;
+ {
+ if (table.comply == luci::PartitionTable::COMPLY::OPCODE)
+ table.default_group = item.second;
+ }
else
{
table.byopcodes.emplace(item.first, item.second);
}
}
}
+ else if (section.name == _section_OPNAME)
+ {
+ auto &items = section.items;
+
+ for (auto &item : items)
+ {
+ if (item.first == _key_underscore)
+ {
+ if (table.comply == luci::PartitionTable::COMPLY::OPNAME)
+ table.default_group = item.second;
+ }
+ else
+ {
+ table.byopnames.emplace(item.first, item.second);
+ }
+ }
+ }
}
return table;
.help("Show version information and exit")
.exit_with(print_version);
+ arser.add_argument("-V", "--verbose")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("output additional information to stdout or stderr");
+
arser.add_argument(qdqw)
.nargs(3)
.type(arser::DataType::STR_VEC)
}
catch (const std::runtime_error &err)
{
- std::cout << err.what() << std::endl;
+ std::cerr << err.what() << std::endl;
std::cout << arser;
return 255;
}
}
}
+ if (arser.get<bool>("--verbose"))
+ {
+ // The third parameter of setenv means REPLACE.
+ // If REPLACE is zero, it does not overwrite an existing value.
+ setenv("LUCI_LOG", "100", 0);
+ }
+
if (arser[qdqw])
{
auto values = arser.get<std::vector<std::string>>(qdqw);
return()
endif(NOT TARGET mio_circle)
-nnas_find_package(HDF5 QUIET)
+nnas_find_package(HDF5 COMPONENTS STATIC QUIET)
if(NOT HDF5_FOUND)
message(STATUS "Build circle-tensordump: FAILED (missing HDF5)")
target_link_libraries(circle-tensordump PRIVATE foder)
target_link_libraries(circle-tensordump PRIVATE mio_circle)
target_link_libraries(circle-tensordump PRIVATE safemain)
+
+install(TARGETS circle-tensordump DESTINATION bin)
auto tensors = reader.tensors();
for (const auto &tensor : *tensors)
{
+ const auto tensor_name = tensor->name();
+ std::string tensor_name_str = tensor_name ? tensor_name->str() : "no_name";
os << std::string(70, '-') << std::endl;
- os << "[" << tensor->name()->str() << "]" << std::endl;
+ os << "[" << tensor_name_str << "]" << std::endl;
auto buff_idx = tensor->buffer();
auto buff_data_ptr = reader.buffers()->Get(buff_idx)->data();
auto quant_param = tensor->quantization();
ret.resize(rank);
for (uint32_t d = 0; d < rank; d++)
{
- ret.at(d) = dims->Get(d);
+ if (dims->Get(d) < 0)
+ throw std::runtime_error("Dimensions shouldn't be negative");
+ ret.at(d) = static_cast<hsize_t>(dims->Get(d));
}
}
}
auto tensors = reader.tensors();
for (const auto &tensor : *tensors)
{
+ // If the tensor does not have a name, skip it.
+ const auto tensor_name = tensor->name();
+ if (tensor_name == nullptr)
+ {
+ assert(false && "There is no tensor name");
+ continue;
+ }
+
// create a group for each tensor whose name is its tensor name
- std::string group_name = ::mangle(tensor->name()->c_str());
+ std::string group_name = ::mangle(tensor_name->c_str());
std::unique_ptr<H5::Group> tensor_group =
std::make_unique<H5::Group>(file.createGroup(group_name));
COMMAND ${CMAKE_COMMAND} -E remove -f ${TEST_CONFIG}
COMMAND ${CMAKE_COMMAND} -E echo 'CIRCLE_INSPECT_PATH=\"$<TARGET_FILE:circle-inspect>\"' >> ${TEST_CONFIG}
COMMAND ${CMAKE_COMMAND} -E echo 'CIRCLE_VERIFY_PATH=\"$<TARGET_FILE:circle-verify>\"' >> ${TEST_CONFIG}
- DEPENDS
+ DEPENDS
circle-inspect
circle-verify
COMMENT "Generate test configuration"
Add(Net_Conv_Add_Mul_000 PASS fuse_batchnorm_with_conv)
Add(Net_Conv_Add_Mul_001 PASS fuse_batchnorm_with_conv)
Add(Net_Conv_Add_Mul_002 PASS fuse_batchnorm_with_conv)
+Add(Net_Conv_FakeQuant_000 PASS remove_fakequant)
+Add(Net_Conv_QuantDequant_000 PASS remove_quantdequant)
Add(Net_Conv_Min_Max_000 PASS transform_min_max_to_relu6)
+Add(Net_Conv_Min_Relu_000 PASS transform_min_relu_to_relu6)
Add(Net_Conv_Relu6_000 PASS fuse_activation_function)
Add(Net_DwConv_BN_000 PASS fuse_batchnorm_with_dwconv)
Add(Net_DwConv_BN_001 PASS fuse_batchnorm_with_dwconv)
Add(Net_TConv_BN_001 PASS fuse_batchnorm_with_tconv)
Add(Net_TConv_BN_002 PASS fuse_batchnorm_with_tconv)
Add(Net_InstanceNorm_001 PASS fuse_instnorm)
-Add(Net_InstanceNorm_002 PASS fuse_instnorm)
Add(Net_InstanceNorm_003 PASS fuse_instnorm)
+Add(Net_InstanceNorm_004 PASS fuse_instnorm)
+Add(Net_InstanceNorm_005 PASS fuse_instnorm)
+Add(Net_InstanceNorm_006 PASS fuse_instnorm)
+Add(Net_InstanceNorm_007 PASS fuse_instnorm)
Add(Net_Maximum_Minimum_000 PASS transform_min_max_to_relu6)
Add(BatchMatMulV2_000 PASS resolve_customop_batchmatmul)
Add(MatMul_000 PASS resolve_customop_matmul)
Add(DepthwiseConv2D_003 PASS)
+Add(StridedSlice_003 PASS substitute_strided_slice_to_reshape)
+Add(MaxPoolWithArgmax_000 PASS resolve_customop_max_pool_with_argmax)
+Add(MaxPoolWithArgmax_001 PASS resolve_customop_max_pool_with_argmax)
+Add(MaxPoolWithArgmax_002 PASS resolve_customop_max_pool_with_argmax)
## CIRCLE RECIPE
Add(CircleBatchMatMul_000)
+
+# REGRESSION test
+
+Add(REGRESS_ONNX_Conv_BN_001 PASS
+ convert_nchw_to_nhwc
+ nchw_to_nhwc_input_shape
+ nchw_to_nhwc_output_shape
+ remove_redundant_transpose
+ substitute_transpose_to_reshape
+ remove_redundant_reshape
+ remove_unnecessary_reshape
+ fuse_batchnorm_with_conv)
+
+Add(REGRESS_ONNX_Conv_BN_Relu6_001 PASS
+ convert_nchw_to_nhwc
+ nchw_to_nhwc_input_shape
+ nchw_to_nhwc_output_shape
+ remove_redundant_transpose
+ transform_min_max_to_relu6
+ fuse_batchnorm_with_conv
+ fuse_activation_function)
+
+Add(REGRESS_ONNX_Conv_BN_MeanMean_001 PASS
+ convert_nchw_to_nhwc
+ nchw_to_nhwc_input_shape
+ nchw_to_nhwc_output_shape
+ remove_redundant_transpose
+ fuse_batchnorm_with_conv
+ fuse_activation_function
+ fuse_mean_with_mean
+ fuse_transpose_with_mean)
list(REMOVE_ITEM SOURCES ${TESTS})
add_executable(circle2circle "${SOURCES}")
-target_include_directories(circle2circle PRIVATE include)
target_include_directories(circle2circle PRIVATE src)
target_link_libraries(circle2circle foder)
target_link_libraries(circle2circle nncc_common)
nnas_find_package(GTest REQUIRED)
GTest_AddTest(circle2circle_test ${TESTS} ${SOURCES})
-target_include_directories(circle2circle_test PRIVATE include)
target_include_directories(circle2circle_test PRIVATE src)
target_link_libraries(circle2circle_test foder)
target_link_libraries(circle2circle_test nncc_common)
#include <luci/Importer.h>
#include <luci/CircleOptimizer.h>
+#include <luci/Service/ChangeOutputs.h>
#include <luci/Service/Validate.h>
#include <luci/CircleExporter.h>
#include <luci/CircleFileExpContract.h>
#include <functional>
#include <iostream>
+#include <sstream>
#include <string>
+#include <vector>
+#include <cstdlib>
using Algorithms = luci::CircleOptimizer::Options::Algorithm;
using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
std::cout << vconone::get_copyright() << std::endl;
}
+void csv_tokenize(const std::string &data, std::vector<std::string> &result)
+{
+ const char delim = ',';
+ std::string token;
+ std::stringstream ss(data);
+
+ while (std::getline(ss, token, delim))
+ result.push_back(token);
+}
+
int entry(int argc, char **argv)
{
// Simple argument parser (based on map)
.help("Show version information and exit")
.exit_with(print_version);
+ arser.add_argument("-V", "--verbose")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("output additional information to stdout or stderr");
+
arser.add_argument("--O1").nargs(0).required(false).default_value(false).help(
"Enable O1 optimize options");
.default_value(false)
.help("This will fuse operators to InstanceNorm operator");
+ arser.add_argument("--fuse_mean_with_mean")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will fuse two Mean operations when they follow one by one."
+ "This will fold them into one operation and merge reduction indices.");
+
+ arser.add_argument("--fuse_transpose_with_mean")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will fuse Mean operation with a preceding Transpose under certain conditions.");
+
arser.add_argument("--make_batchnorm_gamma_positive")
.nargs(0)
.required(false)
.default_value(false)
.help("This will fuse BatchNorm operators of pre-activations to Convolution operator");
+ arser.add_argument("--remove_fakequant")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will remove FakeQuant operators");
+
+ arser.add_argument("--remove_quantdequant")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will remove Quantize-Dequantize sequence");
+
arser.add_argument("--remove_redundant_reshape")
.nargs(0)
.required(false)
.default_value(false)
.help("This will replace channel-wise mul/add with DepthwiseConv2D operator");
+ arser.add_argument("--replace_sub_with_add")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will replace sub with add operator");
+
arser.add_argument("--resolve_customop_add")
.nargs(0)
.required(false)
.default_value(false)
.help("This will convert Custom(Matmul) to Matmul operator");
+ arser.add_argument("--resolve_customop_max_pool_with_argmax")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will convert Custom(MaxPoolWithArgmax) to equivalent set of operators");
+
arser.add_argument("--shuffle_weight_to_16x1float32")
.nargs(0)
.required(false)
.default_value(false)
.help("This will convert single input Pack to Reshape");
+ arser.add_argument("--substitute_padv2_to_pad")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will convert certain condition PadV2 to Pad");
+
arser.add_argument("--substitute_squeeze_to_reshape")
.nargs(0)
.required(false)
.default_value(false)
.help("This will convert certain condition Squeeze to Reshape");
+ arser.add_argument("--substitute_strided_slice_to_reshape")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will convert certain condition Strided_Slice to Reshape");
+
arser.add_argument("--substitute_transpose_to_reshape")
.nargs(0)
.required(false)
.help("Experimental: This will convert NCHW operators to NHWC under the assumption that "
"input model is NCHW.");
- arser.add_argument("--nchw_to_nhwc_preserve_input_shape")
+ arser.add_argument("--nchw_to_nhwc_input_shape")
.nargs(0)
.required(false)
.default_value(false)
- .help("Preserve the input shape of the model (argument for --convert_nchw_to_nhwc).");
+ .help("Convert the input shape of the model (argument for --convert_nchw_to_nhwc).");
- arser.add_argument("--nchw_to_nhwc_preserve_output_shape")
+ arser.add_argument("--nchw_to_nhwc_output_shape")
.nargs(0)
.required(false)
.default_value(false)
- .help("Preserve the output shape of the model (argument for --convert_nchw_to_nhwc).");
+ .help("Convert the output shape of the model (argument for --convert_nchw_to_nhwc).");
arser.add_argument("--transform_min_max_to_relu6")
.nargs(0)
.required(false)
.default_value(false)
- .help("Transform Minimum-Maximum pattern to Relu6 operator");
+ .help("Transform Minimum(6)-Maximum(0) pattern to Relu6 operator");
+
+ arser.add_argument("--transform_min_relu_to_relu6")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("Transform Minimum(6)-Relu pattern to Relu6 operator");
arser.add_argument("--mute_warnings")
.nargs(0)
.default_value(false)
.help("This will turn on profiling data generation.");
+ arser.add_argument("--change_outputs")
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .required(false)
+ .help("Experimental: Change first subgraph output nodes to CSV names");
+
arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
}
catch (const std::runtime_error &err)
{
- std::cout << err.what() << std::endl;
+ std::cerr << err.what() << std::endl;
std::cout << arser;
return 255;
}
+ if (arser.get<bool>("--verbose"))
+ {
+ // The third parameter of setenv means REPLACE.
+ // If REPLACE is zero, it does not overwrite an existing value.
+ setenv("LUCI_LOG", "100", 0);
+ }
if (arser.get<bool>("--O1"))
{
options->enable(Algorithms::FuseBCQ);
options->enable(Algorithms::FuseBCQ);
if (arser.get<bool>("--fuse_instnorm"))
options->enable(Algorithms::FuseInstanceNorm);
+ if (arser.get<bool>("--fuse_mean_with_mean"))
+ options->enable(Algorithms::FuseMeanWithMean);
if (arser.get<bool>("--make_batchnorm_gamma_positive"))
options->enable(Algorithms::MakeBatchNormGammaPositive);
if (arser.get<bool>("--fuse_preactivation_batchnorm"))
options->enable(Algorithms::FusePreActivationBatchNorm);
+ if (arser.get<bool>("--fuse_transpose_with_mean"))
+ options->enable(Algorithms::FuseTransposeWithMean);
+ if (arser.get<bool>("--remove_fakequant"))
+ options->enable(Algorithms::RemoveFakeQuant);
+ if (arser.get<bool>("--remove_quantdequant"))
+ options->enable(Algorithms::RemoveQuantDequantSeq);
if (arser.get<bool>("--remove_redundant_reshape"))
options->enable(Algorithms::RemoveRedundantReshape);
if (arser.get<bool>("--remove_redundant_transpose"))
options->enable(Algorithms::RemoveUnnecessarySplit);
if (arser.get<bool>("--replace_cw_mul_add_with_depthwise_conv"))
options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv);
+ if (arser.get<bool>("--replace_sub_with_add"))
+ options->enable(Algorithms::ReplaceSubWithAdd);
if (arser.get<bool>("--resolve_customop_add"))
options->enable(Algorithms::ResolveCustomOpAdd);
if (arser.get<bool>("--resolve_customop_batchmatmul"))
options->enable(Algorithms::ResolveCustomOpBatchMatMul);
if (arser.get<bool>("--resolve_customop_matmul"))
options->enable(Algorithms::ResolveCustomOpMatMul);
+ if (arser.get<bool>("--resolve_customop_max_pool_with_argmax"))
+ options->enable(Algorithms::ResolveCustomOpMaxPoolWithArgmax);
if (arser.get<bool>("--shuffle_weight_to_16x1float32"))
options->enable(Algorithms::ShuffleWeightTo16x1Float32);
if (arser.get<bool>("--substitute_pack_to_reshape"))
options->enable(Algorithms::SubstitutePackToReshape);
+ if (arser.get<bool>("--substitute_padv2_to_pad"))
+ options->enable(Algorithms::SubstitutePadV2ToPad);
if (arser.get<bool>("--substitute_squeeze_to_reshape"))
options->enable(Algorithms::SubstituteSqueezeToReshape);
+ if (arser.get<bool>("--substitute_strided_slice_to_reshape"))
+ options->enable(Algorithms::SubstituteStridedSliceToReshape);
if (arser.get<bool>("--substitute_transpose_to_reshape"))
options->enable(Algorithms::SubstituteTransposeToReshape);
if (arser.get<bool>("--transform_min_max_to_relu6"))
options->enable(Algorithms::TransformMinMaxToRelu6Pass);
+ if (arser.get<bool>("--transform_min_relu_to_relu6"))
+ options->enable(Algorithms::TransformMinReluToRelu6Pass);
if (arser.get<bool>("--mute_warnings"))
settings->set(luci::UserSettings::Key::MuteWarnings, true);
if (arser.get<bool>("--convert_nchw_to_nhwc"))
{
options->enable(Algorithms::ConvertNCHWToNHWC);
- if (arser.get<bool>("--nchw_to_nhwc_preserve_input_shape"))
- options->param(AlgorithmParameters::NCHW_to_NHWC_preserve_input_shape, "true");
- if (arser.get<bool>("--nchw_to_nhwc_preserve_output_shape"))
- options->param(AlgorithmParameters::NCHW_to_NHWC_preserve_output_shape, "true");
+ if (arser.get<bool>("--nchw_to_nhwc_input_shape"))
+ options->param(AlgorithmParameters::NCHW_to_NHWC_input_shape, "true");
+ if (arser.get<bool>("--nchw_to_nhwc_output_shape"))
+ options->param(AlgorithmParameters::NCHW_to_NHWC_output_shape, "true");
+ }
+
+ // Change output nodes
+ bool change_outputs = false;
+ std::vector<std::string> new_outputs;
+ if (arser["--change_outputs"])
+ {
+ change_outputs = true;
+ auto csv_nodes = arser.get<std::string>("--change_outputs");
+ csv_tokenize(csv_nodes, new_outputs);
}
// Load model from the file
luci::Importer importer;
auto module = importer.importModule(circle_model);
+ if (change_outputs)
+ {
+ auto graph = module->graph(0);
+ luci::change_outputs(graph, new_outputs);
+ }
+
// call luci optimizations for module
optimizer.optimize(module.get());
return circle::TensorType_UINT8;
case circlechef::INT64:
return circle::TensorType_INT64;
+ case circlechef::STRING:
+ return circle::TensorType_STRING;
case circlechef::BOOL:
return circle::TensorType_BOOL;
case circlechef::INT16:
DATA_CHEF(UINT8, explicit, ExplicitDataChefFactory<uint8_t>)
DATA_CHEF(BOOL, explicit, ExplicitDataChefFactory<bool>)
DATA_CHEF(FLOAT32, explicit, ExplicitDataChefFactory<float>)
+DATA_CHEF(STRING, explicit, ExplicitDataChefFactory<std::string>)
DATA_CHEF(FLOAT32, gaussian, GaussianFloat32DataChefFactory)
DATA_CHEF(INT32, gaussian, GaussianInt32DataChefFactory)
DATA_CHEF(INT16, gaussian, GaussianInt16DataChefFactory)
static DataChefRegistry s64;
static DataChefRegistry fp32;
static DataChefRegistry u8;
+ static DataChefRegistry string;
static DataChefRegistry boolean;
static DataChefRegistry s16;
return fp32;
case circlechef::UINT8:
return u8;
+ case circlechef::STRING:
+ return string;
case circlechef::BOOL:
return boolean;
case circlechef::INT16:
#define DATA_CHEF(TYPE, NAME, FACTORY_CLASS) \
data_chef_registry(::circlechef::TYPE) \
.add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
-#include <souschef/DataChef.def>
+#include "DataChef.def"
#undef DATA_CHEF
//
INT32 = 2;
UINT8 = 3;
INT64 = 4;
+ STRING = 5;
BOOL = 6;
INT16 = 7;
}
--- /dev/null
+operand {
+ name: "ifm"
+ shape { }
+ type: STRING
+}
+operand {
+ name: "constant"
+ type: STRING
+ shape { }
+ filler {
+ tag: "explicit"
+ arg: "Hello"
+ }
+}
+operand {
+ name: "ofm"
+ type: STRING
+ shape { }
+}
+operation {
+ type: "BatchMatMul"
+ input: "ifm"
+ input: "constant"
+ output: "ofm"
+ batch_matmul_options {
+ adjoint_lhs: false
+ adjoint_rhs: false
+ }
+}
+input: "ifm"
+output: "ofm"
target_link_libraries(circledump mio_circle)
target_link_libraries(circledump safemain)
target_link_libraries(circledump flatbuffers)
+
+install(TARGETS circledump DESTINATION bin)
os << std::boolalpha;
os << "adjoint_lhs(" << params->adjoint_lhs() << ") ";
os << "adjoint_rhs(" << params->adjoint_rhs() << ") ";
+ os << std::noboolalpha;
os << std::endl;
}
}
os << std::boolalpha;
os << "align_corners(" << resize_params->align_corners() << ")";
os << "half_pixel_centers(" << resize_params->half_pixel_centers() << ")";
+ os << std::noboolalpha;
os << std::endl;
}
}
os << " ";
os << std::boolalpha;
os << "align_corners(" << resize_params->align_corners() << ")";
+ os << std::noboolalpha;
os << std::endl;
}
}
template <typename T> std::vector<T> as_index_vector(const flatbuffers::Vector<T> *flat_array)
{
+ if (flat_array == nullptr)
+ {
+ throw std::runtime_error("flat array is nullptr");
+ }
+
std::vector<T> ret(flat_array->Length());
for (uint32_t i = 0; i < flat_array->Length(); i++)
{
const CircleOperators_t *_operators{nullptr};
const CircleMetadata_t *_metadata{nullptr};
- uint32_t _subgraph_index;
+ uint32_t _subgraph_index = 0;
std::string _subgraph_name;
std::vector<const circle::OperatorCode *> _op_codes;
std::vector<int32_t> _inputs;
std::vector<int32_t> _outputs;
- circle::DataFormat _data_format;
+ circle::DataFormat _data_format = circle::DataFormat::DataFormat_CHANNELS_FIRST;
};
} // namespace circleread
add_custom_command(OUTPUT ${INPUT_HDF5_FILE} ${EXPECTED_HDF5_FILE}
COMMAND $<TARGET_FILE:testDataGenerator> --input_data ${INPUT_HDF5_FILE} --expected_data ${EXPECTED_HDF5_FILE} ${MODEL_FILE}
DEPENDS $<TARGET_FILE:testDataGenerator> ${MODEL_FILE} ${TC_DIRECTORY}
- COMMENT "Generate ${INPUT_HDF5_FILE} and ${EXPECTED_HDF5_FILE}"
+ COMMENT "Generate input.h5 and expected.h5 in ${NNPKG_FILE}/metadata/tc"
)
list(APPEND TEST_DEPS ${INPUT_HDF5_FILE} ${EXPECTED_HDF5_FILE})
endif()
tcgenerate(AddN_000)
tcgenerate(Add_001) # runtime doesn't support
tcgenerate(Add_U8_000)
+tcgenerate(Add_STR_000) # STRING is not supported
+tcgenerate(Add_STR_001) # STRING is not supported
tcgenerate(All_000)
tcgenerate(ArgMin_000)
tcgenerate(ArgMin_001)
tcgenerate(Gather_000)
tcgenerate(GatherNd_000)
tcgenerate(GatherNd_001)
-tcgenerate(If_000)
-tcgenerate(If_001)
tcgenerate(L2Pool2D_U8_000)
tcgenerate(Log_000)
tcgenerate(MatMul_000)
tcgenerate(MatrixBandPart_000)
tcgenerate(MatrixDiag_000)
tcgenerate(MatrixSetDiag_000)
-tcgenerate(MaxPoolWithArgMax_000)
-tcgenerate(MaxPoolWithArgMax_001)
-tcgenerate(MaxPoolWithArgMax_002)
+tcgenerate(MaxPoolWithArgmax_000)
+tcgenerate(MaxPoolWithArgmax_001)
+tcgenerate(MaxPoolWithArgmax_002)
tcgenerate(Mean_dynamic_000) # TestDataGenerator does not support unknown dimension
tcgenerate(Mean_dynamic_001) # TestDataGenerator does not support unknown dimension
tcgenerate(Mean_U8_dynamic_000) # TestDataGenerator does not support unknown dimension
tcgenerate(Mul_U8_000)
tcgenerate(Neg_000)
tcgenerate(Net_BroadcastTo_AddV2_001) # luci-interpreter doesn't support custom operator
+tcgenerate(Net_Conv_FakeQuant_000) # luci-interpreter doesn't support FakeQuant yet
+tcgenerate(Net_Conv_QuantDequant_000) # luci-interpreter doesn't support Quantize/Dequantize yet
tcgenerate(Net_Dangle_001)
tcgenerate(Net_ZeroDim_001) # luci-interpreter doesn't support zero dim
tcgenerate(OneHot_000)
tcgenerate(Pack_000)
tcgenerate(Pack_U8_000)
tcgenerate(PadV2_000)
+tcgenerate(Quantize_000) # runtime and luci-interpreter don't support the Quantize op yet
tcgenerate(Range_000)
tcgenerate(Rank_000)
tcgenerate(ReduceAny_000)
tcgenerate(Unique_U8_001)
tcgenerate(Where_000)
tcgenerate(Where_001)
-tcgenerate(While_000)
-tcgenerate(While_001)
-tcgenerate(While_002)
-tcgenerate(While_003)
+tcgenerate(While_000) # Needs luci-interpreter int32_t support for ADD, EQUAL
+tcgenerate(While_001) # Needs luci-interpreter int32_t support for ADD, EQUAL
+tcgenerate(While_002) # Needs luci-interpreter int32_t support for ADD, EQUAL
+tcgenerate(While_003) # Needs luci-interpreter int32_t support for ADD, EQUAL, and dynamic shape for WHILE
tcgenerate(YUV_TO_RGB_000)
tcgenerate(YUV_TO_RGB_U8_000)
tcgenerate(ZerosLike_000)
}
}
+template <> void geneate_random_data<bool>(std::mt19937 &gen, void *data, uint32_t size)
+{
+ std::normal_distribution<float> distrib(0, 2); // mean(0), stddev(2)
+ for (uint32_t i = 0; i < size; i++)
+ {
+ static_cast<bool *>(data)[i] = distrib(gen) >= 0 ? true : false;
+ }
+}
+
void fill_random_data(void *data, uint32_t size, loco::DataType dtype, uint32_t seed)
{
std::mt19937 gen(seed); // standard mersenne_twister_engine seeded with 'seed'
case loco::DataType::FLOAT32:
geneate_random_data<float>(gen, data, size);
break;
- default:
+ case loco::DataType::BOOL:
+ geneate_random_data<bool>(gen, data, size);
break;
+ default:
+ throw std::runtime_error("NYI data type.");
}
}
std::random_device rd; // used to obtain a seed for the random number engine
uint32_t input_index = 0;
- for (uint32_t g = 0; g < circle_model->subgraphs()->size(); g++)
+ // TODO remove indentation
{
- const auto input_nodes = loco::input_nodes(module->graph(g));
+ // NOTE we only need to prepare data for main graph (subgraph 0) as
+ // other subgraphs are invoked by the main graph
+ const auto input_nodes = loco::input_nodes(module->graph(0));
for (const auto &node : input_nodes)
{
const auto *input_node = dynamic_cast<const luci::CircleInput *>(node);
// dump output data into hdf5 file
uint32_t output_index = 0;
- for (uint32_t g = 0; g < circle_model->subgraphs()->size(); g++)
+ // TODO remove indentation
{
- const auto output_nodes = loco::output_nodes(module->graph(g));
+ const auto output_nodes = loco::output_nodes(module->graph(0));
for (const auto &node : output_nodes)
{
const auto *output_node = dynamic_cast<const luci::CircleOutput *>(node);
public:
FileLoader(const FileLoader &) = delete;
- FileLoader(FileLoader &&) = delete;
+ FileLoader &operator=(const FileLoader &) = delete;
public:
DataBuffer load(void) const
std::ifstream file(_path, std::ios::binary | std::ios::in);
if (!file.good())
{
- std::string errmsg = "ERROR: Failed to open file: " + _path;
+ std::string errmsg = "Failed to open file: " + _path;
throw std::runtime_error(errmsg.c_str());
}
file.read(data.data(), fileSize);
if (file.fail())
{
- std::string errmsg = "ERROR: Failed to read file: " + _path;
+ std::string errmsg = "Failed to read file: " + _path;
throw std::runtime_error(errmsg.c_str());
}
target_link_libraries(loco PUBLIC nncc_coverage)
# Q. HOW TO MAKE DEV PACKAGE(?)
install(TARGETS loco DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
if(NOT ENABLE_TEST)
return()
// WARNING the size of Bool may vary for NN frameworks
// TODO we need to find a way to resolve this issue
BOOL, // Boolean
+
+ // WARNING STRING is NOT fully supported yet
+ STRING, // String
};
} // namespace loco
#include <cassert>
#include <cstdint>
+#include <stdexcept>
namespace loco
{
using Type = uint8_t;
};
+template <> struct DataTypeImpl<DataType::STRING>
+{
+ // Use C++ std::string type for STRING
+ using Type = std::string;
+};
+
/**
* @brief Returns the size of the data type.
* @note If you need the size at compile time, use `sizeof(typename DataTypeImpl<DT>::Type)`.
return sizeof(DataTypeImpl<DataType::FLOAT64>::Type);
case DataType::BOOL:
return sizeof(DataTypeImpl<DataType::BOOL>::Type);
+ case DataType::STRING:
+ // STRING is variable length; its size cannot be determined from the type alone
+ throw std::runtime_error("Invalid size call with STRING type");
default:
// TODO Support remaining data types.
assert(false);
std::vector<Dimension> _dims;
};
-template <unsigned N> struct FixedArity
+template <uint32_t N> struct FixedArity
{
template <typename Base> class Mixin : public virtual Base
{
virtual ~Mixin() = default;
public:
- unsigned arity(void) const final { return N; }
+ uint32_t arity(void) const final { return N; }
Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
protected:
// This API allows inherited classes to access "_args" field.
- Use *at(unsigned n) const { return _args.at(n).get(); }
+ Use *at(uint32_t n) const { return _args.at(n).get(); }
private:
std::array<std::unique_ptr<Use>, N> _args{};
void execute_node(loco::Pull *pull)
{
-// TODO Remove deprecated code
-#if 0
- validate(annot_data(pull), "Data for Pull is not ready");
-
- validate(annot_domain(pull) == loco::Domain::Tensor, "Domain for Pull is not Tensor");
-
- // DO NOTHING
-#endif
-
auto input_data = user_data(pull);
validate(input_data, "Input not ready");
auto pull_data = locomotiv::make_data(pull_buf);
locomotiv::user_data(pull, std::move(pull_data));
-// The behavior of Pull is now consistent with that of other nodes.
-// - annot_data and annot_domain is available after evaluating that "pull" node.
-// TODO Remove this
-#if 0
- // Domain not ready yet
- ASSERT_ANY_THROW(locomotiv::NodeExecution::get().run(pull));
-
- // Set Domain
- locomotiv::annot_domain(pull, loco::Domain::Tensor);
-#endif
+ // The behavior of Pull is now consistent with that of other nodes.
+ // - annot_data and annot_domain are available after evaluating the "pull" node.
// Valid run
ASSERT_NO_THROW(locomotiv::NodeExecution::get().run(pull));
#include "NodeExecution.h"
-// TODO Remove deprecated code
-#if 0
-#include "NodeDataImpl.h"
-#include "NodeDomain.h"
-#include "Validation.h"
-
-#include <nncc/core/ADT/tensor/Shape.h>
-#include <nncc/core/ADT/tensor/Buffer.h>
-#include <nncc/core/ADT/tensor/IndexEnumerator.h>
-#include <nncc/core/ADT/tensor/LexicalLayout.h>
-
-using nncc::core::ADT::tensor::IndexEnumerator;
-using nncc::core::ADT::tensor::LexicalLayout;
-using nncc::core::ADT::tensor::make_buffer;
-
-#include <cassert>
-#include <stdexcept>
-#endif
-
namespace
{
void NodeExecution::execute(loco::ReLU6 *relu6)
{
-// TODO Remove deprecated code
-#if 0
- auto input_data = annot_data(relu6->input());
-
- validate(input_data, "Input not ready");
- validate(annot_domain(relu6->input()) != loco::Domain::Unknown,
- "Input domain of ReLU is Unknown");
-
- std::unique_ptr<NodeData> relu6_data = nullptr;
-
- switch (input_data->dtype())
- {
- case loco::DataType::FLOAT32:
- {
- auto input_bufptr = input_data->as_f32_bufptr();
- auto *shape = input_data->shape();
- auto relu6_buf = make_buffer<float, LexicalLayout>(*shape);
-
- for (IndexEnumerator e{*shape}; e.valid(); e.advance())
- {
- const auto &index = e.current();
- relu6_buf.at(index) = relu6_ew(input_bufptr->at(index));
- }
-
- relu6_data = make_data(relu6_buf);
- break;
- }
- default:
- throw std::runtime_error("NYI for this DataType");
- }
-
- assert(relu6_data != nullptr);
- annot_data(relu6, std::move(relu6_data));
- annot_domain(relu6, annot_domain(relu6->input()));
-#endif
-
struct Func final : public UnaryFunc
{
float apply(float v) const final { return relu6_ew(v); }
target_link_libraries(luci_eval_driver PRIVATE luci_lang)
target_link_libraries(luci_eval_driver PRIVATE luci_interpreter)
target_link_libraries(luci_eval_driver PRIVATE safemain)
+
+install(TARGETS luci_eval_driver DESTINATION bin)
set(LUCI_INTERPRETER_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include")
set(LUCI_INTERPRETER_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src")
+if (NOT LUCI_INTERPRETER_PAL_DIR)
+ set(LUCI_INTERPRETER_PAL_DIR "${CMAKE_CURRENT_SOURCE_DIR}/pal/linux")
+endif()
add_subdirectory(src)
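As a usage note (a sketch, not part of the change itself): since LUCI_INTERPRETER_PAL_DIR is only defaulted when unset, a build can point it at another PAL before this CMakeLists is processed, for example with a hypothetical cache entry such as -DLUCI_INTERPRETER_PAL_DIR=<path-to>/luci-interpreter/pal/mcu; the pal.cmake found in that directory then supplies initialize_pal() and add_pal_to_target(), as shown in the new files below.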
--- /dev/null
+macro(initialize_pal)
+ nnas_find_package(TensorFlowSource EXACT 2.3.0 QUIET)
+ nnas_find_package(TensorFlowGEMMLowpSource EXACT 2.3.0 QUIET)
+ nnas_find_package(TensorFlowEigenSource EXACT 2.3.0 QUIET)
+ nnas_find_package(TensorFlowRuySource EXACT 2.3.0 QUIET)
+
+ if (NOT TensorFlowSource_FOUND)
+ message(STATUS "Skipping luci-interpreter: TensorFlow not found")
+ return()
+ endif ()
+
+ if (NOT TensorFlowGEMMLowpSource_FOUND)
+ message(STATUS "Skipping luci-interpreter: gemmlowp not found")
+ return()
+ endif ()
+
+ if (NOT TensorFlowEigenSource_FOUND)
+ message(STATUS "Skipping luci-interpreter: Eigen not found")
+ return()
+ endif ()
+
+ if (NOT TensorFlowRuySource_FOUND)
+ message(STATUS "Skipping luci-interpreter: Ruy not found")
+ return()
+ endif ()
+
+ find_package(Threads REQUIRED)
+
+ set(PAL_INITIALIZED TRUE)
+endmacro()
+
+macro(add_pal_to_target TGT)
+ target_include_directories(${TGT} PRIVATE "${PAL}")
+ target_include_directories(${TGT} SYSTEM PRIVATE
+ "${TensorFlowRuySource_DIR}"
+ "${TensorFlowGEMMLowpSource_DIR}"
+ "${TensorFlowEigenSource_DIR}"
+ "${TensorFlowSource_DIR}")
+ target_include_directories(${TGT} PRIVATE ${LUCI_INTERPRETER_PAL_DIR})
+
+ # TODO put it back, I changed my mind.
+ # instead add sources with visitors in this library
+ set(PAL_SOURCES ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc)
+ add_library(luci_interpreter_linux_pal STATIC ${PAL_SOURCES})
+ set_target_properties(luci_interpreter_linux_pal PROPERTIES POSITION_INDEPENDENT_CODE ON)
+ target_include_directories(luci_interpreter_linux_pal SYSTEM PRIVATE "${TensorFlowSource_DIR}")
+
+ target_link_libraries(${TGT} PRIVATE Threads::Threads luci_interpreter_linux_pal)
+endmacro()
--- /dev/null
+macro(initialize_pal)
+ nnas_find_package(TensorFlowSource EXACT 2.3.0 QUIET)
+ nnas_find_package(TensorFlowGEMMLowpSource EXACT 2.3.0 QUIET)
+ nnas_find_package(TensorFlowEigenSource EXACT 2.3.0 QUIET)
+ nnas_find_package(TensorFlowRuySource EXACT 2.3.0 QUIET)
+
+ if (NOT TensorFlowSource_FOUND)
+ message(STATUS "Skipping luci-interpreter: TensorFlow not found")
+ return()
+ endif ()
+
+ if (NOT TensorFlowGEMMLowpSource_FOUND)
+ message(STATUS "Skipping luci-interpreter: gemmlowp not found")
+ return()
+ endif ()
+
+ if (NOT TensorFlowEigenSource_FOUND)
+ message(STATUS "Skipping luci-interpreter: Eigen not found")
+ return()
+ endif ()
+
+ if (NOT TensorFlowRuySource_FOUND)
+ message(STATUS "Skipping luci-interpreter: Ruy not found")
+ return()
+ endif ()
+ #find_package(Threads REQUIRED)
+
+ set(PAL_INITIALIZED TRUE)
+endmacro()
+
+macro(add_pal_to_target TGT)
+ target_include_directories(${TGT} PRIVATE "${PAL}")
+ target_include_directories(${TGT} SYSTEM PRIVATE
+ "${TensorFlowRuySource_DIR}"
+ "${TensorFlowGEMMLowpSource_DIR}"
+ "${TensorFlowEigenSource_DIR}"
+ "${TensorFlowSource_DIR}")
+ target_include_directories(${TGT} PRIVATE ${LUCI_INTERPRETER_PAL_DIR})
+
+ # TODO put it back, I changed my mind.
+ # instead add sources with visitors in this library
+ set(PAL_SOURCES ${TensorFlowSource_DIR}/tensorflow/lite/kernels/internal/quantization_util.cc)
+ add_library(luci_interpreter_mcu_pal STATIC ${PAL_SOURCES})
+ set_target_properties(luci_interpreter_mcu_pal PROPERTIES POSITION_INDEPENDENT_CODE ON)
+ target_include_directories(luci_interpreter_mcu_pal SYSTEM PRIVATE "${TensorFlowSource_DIR}")
+
+ target_link_libraries(${TGT} PRIVATE luci_interpreter_mcu_pal)
+ #target_link_libraries(${TGT} PRIVATE Threads::Threads luci_interpreter_mcu_pal)
+endmacro()
-nnas_find_package(TensorFlowSource EXACT 2.3.0 QUIET)
-nnas_find_package(TensorFlowGEMMLowpSource EXACT 2.3.0 QUIET)
-nnas_find_package(TensorFlowEigenSource EXACT 2.3.0 QUIET)
-nnas_find_package(TensorFlowRuySource EXACT 2.3.0 QUIET)
+include(${LUCI_INTERPRETER_PAL_DIR}/pal.cmake)
-if (NOT TensorFlowSource_FOUND)
- message(STATUS "Skipping luci-interpreter: TensorFlow not found")
- return()
-endif ()
+initialize_pal()
-if (NOT TensorFlowGEMMLowpSource_FOUND)
- message(STATUS "Skipping luci-interpreter: gemmlowp not found")
+if (NOT PAL_INITIALIZED)
return()
-endif ()
+endif()
-if (NOT TensorFlowEigenSource_FOUND)
- message(STATUS "Skipping luci-interpreter: Eigen not found")
- return()
-endif ()
-
-if (NOT TensorFlowRuySource_FOUND)
- message(STATUS "Skipping luci-interpreter: Ruy not found")
- return()
-endif ()
+message(STATUS "LUCI INTERPRETER BEGIN")
add_subdirectory(core)
+message(STATUS "LUCI INTERPRETER CORE")
add_subdirectory(kernels)
+message(STATUS "LUCI INTERPRETER KERNELS")
add_subdirectory(loader)
+message(STATUS "LUCI INTERPRETER LOADER")
+
+message(STATUS "LUCI INTERPTER INITALIZED")
set(SOURCES
"${LUCI_INTERPRETER_INCLUDE_DIR}/luci_interpreter/Interpreter.h"
PRIVATE nncc_common)
install(TARGETS luci_interpreter DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
#include <luci/IR/AttrPadding.h>
#include <luci/IR/AttrFusedActFunc.h>
+#include <luci/IR/AttrMirrorPadMode.h>
#include <luci_interpreter/core/DataType.h>
#include <cstdint>
// Inject commonly used types into `luci_interpreter` namespace for convenience.
using Activation = luci::FusedActFunc;
using Padding = luci::Padding;
+using MirrorPadMode = luci::MirrorPadMode;
struct AddParams
{
float beta;
};
+struct MirrorPadParams
+{
+ MirrorPadMode mode;
+};
+
struct MulParams
{
Activation activation;
AveragePool2D.cpp
BatchToSpaceND.h
BatchToSpaceND.cpp
+ Cast.h
+ Cast.cpp
Concatenation.h
Concatenation.cpp
Conv2D.h
Mean.cpp
Minimum.h
Minimum.cpp
+ MirrorPad.h
+ MirrorPad.cpp
Mul.h
Mul.cpp
Neg.h
Pack.cpp
Pad.h
Pad.cpp
+ PadV2.h
+ PadV2.cpp
Pow.h
Pow.cpp
- Prelu.h
- Prelu.cpp
+ PRelu.h
+ PRelu.cpp
Relu.h
Relu.cpp
Relu6.h
ResizeBilinear.cpp
ResizeNearestNeighbor.h
ResizeNearestNeighbor.cpp
- Reverse.h
- Reverse.cpp
+ ReverseV2.h
+ ReverseV2.cpp
Rsqrt.h
Rsqrt.cpp
Slice.h
StridedSlice.cpp
Sqrt.h
Sqrt.cpp
+ Square.h
+ Square.cpp
SquaredDifference.h
SquaredDifference.cpp
Squeeze.h
TransposeConv.h
TransposeConv.cpp
Unpack.h
- Unpack.cpp)
+ Unpack.cpp
+ While.h
+ While.cpp)
list(APPEND SOURCES
BinaryOpCommon.h
ArgMax.test.cpp
AveragePool2D.test.cpp
BatchToSpaceND.test.cpp
+ Cast.test.cpp
Concatenation.test.cpp
Conv2D.test.cpp
DepthToSpace.test.cpp
NotEqual.test.cpp
Pack.test.cpp
Pad.test.cpp
+ PadV2.test.cpp
Pow.test.cpp
- Prelu.test.cpp
+ PRelu.test.cpp
Relu.test.cpp
Relu6.test.cpp
Reshape.test.cpp
ResizeBilinear.test.cpp
ResizeNearestNeighbor.test.cpp
- Reverse.test.cpp
+ ReverseV2.test.cpp
Rsqrt.test.cpp
Slice.test.cpp
Softmax.test.cpp
Split.test.cpp
StridedSlice.test.cpp
Sqrt.test.cpp
+ Square.test.cpp
SquaredDifference.test.cpp
Squeeze.test.cpp
Sub.test.cpp
Tanh.test.cpp
Transpose.test.cpp
TransposeConv.test.cpp
- Unpack.test.cpp)
+ Unpack.test.cpp
+ While.test.cpp)
list(APPEND TEST_SOURCES TestUtils.h TestUtils.cpp)
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Cast.h"
+#include "kernels/Utils.h"
+
+namespace
+{
+
+using namespace luci_interpreter;
+using namespace luci_interpreter::kernels;
+
+template <typename InT, typename OutT>
+void cast_data(const InT *in_data, OutT *out_data, uint32_t elements_count)
+{
+ std::transform(in_data, in_data + elements_count, out_data,
+ [](InT a) { return static_cast<OutT>(a); });
+}
+
+template <typename InT> void cast_from_pointer_to_tensor(const InT *in_data, Tensor *out_tensor)
+{
+ auto const out_type = out_tensor->element_type();
+ auto const elements_count = out_tensor->shape().num_elements();
+
+ switch (out_type)
+ {
+ case loco::DataType::U8:
+ cast_data(in_data, getTensorData<uint8_t>(out_tensor), elements_count);
+ break;
+ case loco::DataType::U16:
+ cast_data(in_data, getTensorData<uint16_t>(out_tensor), elements_count);
+ break;
+ case loco::DataType::U32:
+ cast_data(in_data, getTensorData<uint32_t>(out_tensor), elements_count);
+ break;
+ case loco::DataType::U64:
+ cast_data(in_data, getTensorData<uint64_t>(out_tensor), elements_count);
+ break;
+ case loco::DataType::S8:
+ cast_data(in_data, getTensorData<int8_t>(out_tensor), elements_count);
+ break;
+ case loco::DataType::S16:
+ cast_data(in_data, getTensorData<int16_t>(out_tensor), elements_count);
+ break;
+ case loco::DataType::S32:
+ cast_data(in_data, getTensorData<int32_t>(out_tensor), elements_count);
+ break;
+ case loco::DataType::S64:
+ cast_data(in_data, getTensorData<int64_t>(out_tensor), elements_count);
+ break;
+ case loco::DataType::FLOAT32:
+ cast_data(in_data, getTensorData<float>(out_tensor), elements_count);
+ break;
+ case loco::DataType::BOOL:
+ cast_data(in_data, getTensorData<bool>(out_tensor), elements_count);
+ break;
+ default:
+ throw std::runtime_error("Unsupported output type.");
+ }
+}
+
+void cast_from_tensor_to_tensor(const Tensor *in_tensor, Tensor *out_tensor)
+{
+ auto in_type = in_tensor->element_type();
+
+ switch (in_type)
+ {
+ case loco::DataType::U8:
+ cast_from_pointer_to_tensor(getTensorData<uint8_t>(in_tensor), out_tensor);
+ break;
+ case loco::DataType::U16:
+ cast_from_pointer_to_tensor(getTensorData<uint16_t>(in_tensor), out_tensor);
+ break;
+ case loco::DataType::U32:
+ cast_from_pointer_to_tensor(getTensorData<uint32_t>(in_tensor), out_tensor);
+ break;
+ case loco::DataType::U64:
+ cast_from_pointer_to_tensor(getTensorData<uint64_t>(in_tensor), out_tensor);
+ break;
+ case loco::DataType::S8:
+ cast_from_pointer_to_tensor(getTensorData<int8_t>(in_tensor), out_tensor);
+ break;
+ case loco::DataType::S16:
+ cast_from_pointer_to_tensor(getTensorData<int16_t>(in_tensor), out_tensor);
+ break;
+ case loco::DataType::S32:
+ cast_from_pointer_to_tensor(getTensorData<int32_t>(in_tensor), out_tensor);
+ break;
+ case loco::DataType::S64:
+ cast_from_pointer_to_tensor(getTensorData<int64_t>(in_tensor), out_tensor);
+ break;
+ case loco::DataType::FLOAT32:
+ cast_from_pointer_to_tensor(getTensorData<float>(in_tensor), out_tensor);
+ break;
+ case loco::DataType::BOOL:
+ cast_from_pointer_to_tensor(getTensorData<bool>(in_tensor), out_tensor);
+ break;
+ default:
+ throw std::runtime_error("Unsupported input type.");
+ }
+}
+
+} // namespace
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+Cast::Cast(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
+
+void Cast::configure()
+{
+ LUCI_INTERPRETER_CHECK(input()->element_type() != loco::DataType::Unknown);
+ LUCI_INTERPRETER_CHECK(output()->element_type() != loco::DataType::Unknown);
+
+ const Shape &shape = input()->shape();
+ output()->resize(shape);
+}
+
+void Cast::execute() const
+{
+ assert(input()->shape().num_elements() == output()->shape().num_elements());
+
+ cast_from_tensor_to_tensor(input(), output());
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_CAST_H
+#define LUCI_INTERPRETER_KERNELS_CAST_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class Cast : public Kernel
+{
+public:
+ Cast(const Tensor *input, Tensor *output);
+
+ const Tensor *input() const { return _inputs[0]; }
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_CAST_H
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Cast.h"
+#include "kernels/TestUtils.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+template <typename T1, typename T2>
+void Check(std::initializer_list<int32_t> shape, std::initializer_list<T1> input_data,
+ std::initializer_list<T2> output_data)
+{
+ constexpr DataType input_type = getElementType<T1>();
+ constexpr DataType output_type = getElementType<T2>();
+
+ Tensor input_tensor = makeInputTensor<input_type>(shape, input_data);
+ Tensor output_tensor = makeOutputTensor(output_type);
+
+ Cast kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<T2>(output_tensor), ::testing::ElementsAreArray(output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), shape);
+}
+
+template <typename T> class CastTest : public ::testing::Test
+{
+};
+
+using DataTypes = ::testing::Types<uint8_t, int32_t, int64_t>;
+TYPED_TEST_CASE(CastTest, DataTypes);
+
+TYPED_TEST(CastTest, FloatToInt)
+{
+ Check<float, TypeParam>(/*shape=*/{1, 1, 1, 4},
+ /*input_data=*/
+ {
+ 1.43f, 9.99f, 7.0f, 3.12f, //
+ },
+ /*output_data=*/
+ {
+ 1, 9, 7, 3, //
+ });
+ Check<TypeParam, TypeParam>(/*shape=*/{1, 1, 1, 4},
+ /*input_data=*/
+ {
+ 1, 9, 7, 3, //
+ },
+ /*output_data=*/
+ {
+ 1, 9, 7, 3, //
+ });
+}
+
+TEST(CastTest, UnsupportedType_NEG)
+{
+ Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 1, 2, 4}, {
+ 1, 2, 7, 8, //
+ 1, 9, 7, 3, //
+ });
+ Tensor output_tensor = makeOutputTensor(DataType::Unknown);
+
+ Cast kernel(&input_tensor, &output_tensor);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
LUCI_INTERPRETER_CHECK(input()->shape().num_dims() == 4);
LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
LUCI_INTERPRETER_CHECK(gamma()->element_type() == input()->element_type());
+ LUCI_INTERPRETER_CHECK(gamma()->shape().num_dims() == 1);
+ LUCI_INTERPRETER_CHECK(gamma()->shape().dim(0) == input()->shape().dim(3) ||
+ gamma()->shape().dim(0) == 1);
LUCI_INTERPRETER_CHECK(beta()->element_type() == input()->element_type());
+ LUCI_INTERPRETER_CHECK(beta()->shape().num_dims() == 1);
+ LUCI_INTERPRETER_CHECK(beta()->shape().dim(0) == input()->shape().dim(3) ||
+ beta()->shape().dim(0) == 1);
output()->resize(input()->shape());
}
const int32_t channels = tflite::MatchingDim(input_shape, 3, output_shape, 3);
const float *input_data = getTensorData<float>(input());
const float *gamma_data = getTensorData<float>(gamma());
+ auto gamma_shape = getTensorShape(gamma());
+ bool single_gamma = gamma_shape.DimensionsCount() == 1 && gamma_shape.Dims(0) == 1;
const float *beta_data = getTensorData<float>(beta());
+ auto beta_shape = getTensorShape(beta());
+ bool single_beta = beta_shape.DimensionsCount() == 1 && beta_shape.Dims(0) == 1;
float *output_data = getTensorData<float>(output());
for (int32_t batch = 0; batch < batches; batch++)
{
double mean = sum / size;
double var = square_sum / size - mean * mean;
- double gamma = gamma_data[channel];
- double beta = beta_data[channel];
+ double gamma = single_gamma ? gamma_data[0] : gamma_data[channel];
+ double beta = single_beta ? beta_data[0] : beta_data[channel];
double a = gamma / (std::sqrt(var + params().epsilon));
double b = -mean * a + beta;
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 2, 1}));
}
+TEST(InstanceNormTest, Single_gamma_beta)
+{
+ Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 2, 1, 2}, {1, 1, 1, 1});
+ Tensor gamma_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1});
+ Tensor beta_tensor = makeInputTensor<DataType::FLOAT32>({1}, {2});
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ InstanceNormParams params{};
+ params.epsilon = 0.1f;
+ params.activation = Activation::NONE;
+
+ InstanceNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params);
+ kernel.configure();
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear({2, 2, 2, 2}));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 2, 1, 2}));
+}
+
+TEST(InstanceNormTest, Wrong_gamma_beta_dim_NEG)
+{
+ Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 2, 1, 2}, {1, 1, 1, 1});
+ Tensor gamma_tensor = makeInputTensor<DataType::FLOAT32>({3}, {1, 1, 1});
+ Tensor beta_tensor = makeInputTensor<DataType::FLOAT32>({3}, {2, 2, 2});
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ InstanceNormParams params{};
+ params.epsilon = 0.1f;
+ params.activation = Activation::NONE;
+
+ InstanceNorm kernel(&input_tensor, &gamma_tensor, &beta_tensor, &output_tensor, params);
+ EXPECT_ANY_THROW(kernel.configure());
+}
+
} // namespace
} // namespace kernels
} // namespace luci_interpreter
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/MirrorPad.h"
+
+#include "kernels/Utils.h"
+
+#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+MirrorPad::MirrorPad(const Tensor *input, const Tensor *paddings, Tensor *output,
+ const MirrorPadParams &params)
+ : KernelWithParams<MirrorPadParams>({input, paddings}, {output}, params)
+{
+}
+
+void MirrorPad::configure()
+{
+ const Shape &input_shape = input()->shape();
+ const int num_dims = input_shape.num_dims();
+
+ if (num_dims > 4)
+ throw std::runtime_error("Unsupported number of dimensions.");
+
+ assert(output()->element_type() == input()->element_type());
+ assert(paddings()->element_type() == DataType::S32);
+ // Paddings shape should be [N, 2].
+ assert(paddings()->shape().num_dims() == 2);
+ assert(paddings()->shape().dim(0) == num_dims);
+ assert(paddings()->shape().dim(1) == 2);
+
+ Shape output_shape(num_dims);
+ const auto *paddings_data = getTensorData<int32_t>(paddings());
+ for (int i = 0; i < num_dims; ++i)
+ {
+ const int32_t padding_before = paddings_data[i * 2];
+ const int32_t padding_after = paddings_data[i * 2 + 1];
+ assert(padding_before >= 0 && padding_after >= 0);
+ output_shape.dim(i) = input_shape.dim(i) + padding_before + padding_after;
+ }
+
+ output()->resize(output_shape);
+}
+
+void MirrorPad::execute() const
+{
+ const int num_dims = input()->shape().num_dims();
+
+ tflite::PadParams params{};
+ params.left_padding_count = num_dims;
+ params.right_padding_count = num_dims;
+
+ const auto *paddings_data = getTensorData<int32_t>(paddings());
+ for (int i = num_dims - 1; i >= 0; --i)
+ {
+ params.left_padding[i] = paddings_data[i * 2];
+ params.right_padding[i] = paddings_data[i * 2 + 1];
+ }
+
+ switch (input()->element_type())
+ {
+ case DataType::FLOAT32:
+ {
+ const float pad_value = 0;
+
+ // NOTE: this implementation only obtains min-max values for quantization
+ // TODO: calculate proper inference values
+ tflite::reference_ops::Pad(params, getTensorShape(input()), getTensorData<float>(input()),
+ &pad_value, getTensorShape(output()),
+ getTensorData<float>(output()));
+ break;
+ }
+ case DataType::U8:
+ {
+ // NOTE: this implementation only obtains min-max values for quantization
+ // TODO: calculate proper inference values
+ assert(output()->zero_point() >= std::numeric_limits<uint8_t>::min());
+ assert(output()->zero_point() <= std::numeric_limits<uint8_t>::max());
+ const auto pad_value = static_cast<uint8_t>(output()->zero_point());
+ tflite::reference_ops::Pad(params, getTensorShape(input()), getTensorData<uint8_t>(input()),
+ &pad_value, getTensorShape(output()),
+ getTensorData<uint8_t>(output()));
+ break;
+ }
+ default:
+ throw std::runtime_error("Unsupported type.");
+ }
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_MIRROR_PAD_H
+#define LUCI_INTERPRETER_KERNELS_MIRROR_PAD_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class MirrorPad : public KernelWithParams<MirrorPadParams>
+{
+public:
+ MirrorPad(const Tensor *input, const Tensor *paddings, Tensor *output,
+ const MirrorPadParams &params);
+
+ const Tensor *input() const { return _inputs[0]; }
+ const Tensor *paddings() const { return _inputs[1]; }
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_MIRROR_PAD_H
* limitations under the License.
*/
-#include "kernels/Prelu.h"
+#include "kernels/PRelu.h"
#include "kernels/BinaryOpCommon.h"
#include "kernels/Utils.h"
namespace kernels
{
-Prelu::Prelu(const Tensor *input, const Tensor *alpha, Tensor *output)
+PRelu::PRelu(const Tensor *input, const Tensor *alpha, Tensor *output)
: Kernel({input, alpha}, {output})
{
}
-Prelu::~Prelu()
+PRelu::~PRelu()
{
// Destructor declared to delete vector of alpha quantized data properly
}
-void Prelu::configure()
+void PRelu::configure()
{
LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
LUCI_INTERPRETER_CHECK(alpha()->element_type() == output()->element_type());
{
LUCI_INTERPRETER_CHECK(alpha()->zero_points()[channel] == 0);
}
- // Prelu specific checks for CWQ
+ // PRelu specific checks for CWQ
LUCI_INTERPRETER_CHECK(alpha()->quantized_dimension() == alpha()->shape().num_dims() - 1);
LUCI_INTERPRETER_CHECK(static_cast<int32_t>(alpha()->scales().size()) ==
alpha()->shape().dim(alpha()->quantized_dimension()));
output()->resize(calculateShapeForBroadcast(input()->shape(), alpha()->shape()));
}
-void Prelu::execute() const
+void PRelu::execute() const
{
switch (input()->element_type())
{
}
}
-void Prelu::evalFloat() const
+void PRelu::evalFloat() const
{
const auto input_data = getTensorData<float>(input());
const auto alpha_data = getTensorData<float>(alpha());
const auto size = getTensorShape(input()).FlatSize();
auto output_data = getTensorData<float>(output());
- auto PreluFunc = [](float input, float alpha) { return input >= 0.0 ? input : input * alpha; };
+ auto PReluFunc = [](float input, float alpha) { return input >= 0.0 ? input : input * alpha; };
if (input()->shape() != alpha()->shape())
{
tflite::reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
getTensorShape(input()), getTensorData<float>(input()), getTensorShape(alpha()),
getTensorData<float>(alpha()), getTensorShape(output()), getTensorData<float>(output()),
- PreluFunc);
+ PReluFunc);
}
else
{
}
}
-void Prelu::evalQuantized() const
+void PRelu::evalQuantized() const
{
tflite::PreluParams op_params{};
}
}
-static inline int16_t evalElemS16Prelu(int16_t input_val, int16_t alpha_val,
+static inline int16_t evalElemS16PRelu(int16_t input_val, int16_t alpha_val,
const ChannelQuantMultipliers &identity_mult,
const ChannelQuantMultipliers &alpha_mult)
{
return clamped_output;
}
-void Prelu::evalQuantizedS16() const
+void PRelu::evalQuantizedS16() const
{
// Note that this kernel assumes alpha is CWQ
tflite::RuntimeShape input_shape = getTensorShape(input());
offset += quant_channel;
output_data[offset] =
- evalElemS16Prelu(input_data[offset], alpha_data[quant_channel], pos_mult, neg_mult);
+ evalElemS16PRelu(input_data[offset], alpha_data[quant_channel], pos_mult, neg_mult);
}
}
class ChannelQuantMultipliers;
-class Prelu : public Kernel
+class PRelu : public Kernel
{
public:
- Prelu(const Tensor *input, const Tensor *alpha, Tensor *output);
+ PRelu(const Tensor *input, const Tensor *alpha, Tensor *output);
- ~Prelu();
+ ~PRelu();
const Tensor *input() const { return _inputs[0]; }
const Tensor *alpha() const { return _inputs[1]; }
* limitations under the License.
*/
-#include "kernels/Prelu.h"
+#include "kernels/PRelu.h"
#include "kernels/TestUtils.h"
namespace luci_interpreter
Tensor alpha_tensor = makeInputTensor<element_type>(alpha_shape, alpha_data);
Tensor output_tensor = makeOutputTensor(element_type);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
kernel.configure();
kernel.execute();
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
}
-TEST(PreluTest, FloatSimple)
+TEST(PReluTest, FloatSimple)
{
Check<float>(/*input_shape=*/{2, 3}, /*alpha_shape=*/{2, 3},
/*output_shape=*/{2, 3},
SUCCEED();
}
-TEST(PreluTest, FloatBroadcast)
+TEST(PReluTest, FloatBroadcast)
{
Check<float>(/*input_shape=*/{1, 2, 2, 3}, /*alpha_shape=*/{1, 1, 3},
/*output_shape=*/{1, 2, 2, 3},
float GetTolerance(float min, float max) { return (max - min) / 255.0; }
-TEST(PreluTest, Uint8Simple)
+TEST(PReluTest, Uint8Simple)
{
std::vector<float> input_data{-0.8f, 0.2f, 0.9f, 0.7f, 0.1f, -0.4f};
std::vector<float> alpha_data{0.5f, 0.5f, 0.5f, 0.25f, 1.0f, 0.25f};
makeInputTensor<DataType::U8>({1, 2, 3, 1}, quant_param.first, quant_param.second, alpha_data);
Tensor output_tensor = makeOutputTensor(DataType::U8, quant_param.first, quant_param.second);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
kernel.configure();
kernel.execute();
SUCCEED();
}
-TEST(PreluTest, Uint8Broadcast)
+TEST(PReluTest, Uint8Broadcast)
{
std::vector<float> input_data{
0.0f, 0.0f, 0.0f, // Row 1, Column 1
makeInputTensor<DataType::U8>({1, 1, 3}, quant_param.first, quant_param.second, alpha_data);
Tensor output_tensor = makeOutputTensor(DataType::U8, quant_param.first, quant_param.second);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
kernel.configure();
kernel.execute();
::testing::ElementsAreArray(ref_quant_output_data));
}
-TEST(PreluTest, SInt16_LWQ_NEG)
+TEST(PReluTest, SInt16_LWQ_NEG)
{
// Rewrite this test in case layer-wise quantization for sint16 is supported
std::vector<float> input_data(6); // data is not important
Tensor alpha_tensor = makeInputTensor<DataType::S16>({1, 2, 3, 1}, 0.1, 0, alpha_data);
Tensor output_tensor = makeOutputTensor(DataType::S16, 0.1, 0);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
EXPECT_ANY_THROW(kernel.configure());
}
-TEST(PreluTest, SInt16_CWQ_Simple)
+TEST(PReluTest, SInt16_CWQ_Simple)
{
std::vector<float> input_data{-0.8f, 0.2f, 0.9f, -0.7f, 0.1f, -0.4f};
std::vector<float> alpha_data{0.5f, 0.25f};
Tensor alpha_tensor = makeInputTensor<DataType::S16>({2}, alpha_scales, zerop, 0, alpha_data);
Tensor output_tensor = makeOutputTensor(DataType::S16, 0.025, 0);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
kernel.configure();
kernel.execute();
EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(ref_output_data));
}
-TEST(PreluTest, SInt16_CWQ_spatial_alpha_NEG)
+TEST(PReluTest, SInt16_CWQ_spatial_alpha_NEG)
{
std::vector<float> input_data(6); // data is not important
std::vector<float> alpha_data(6);
makeInputTensor<DataType::S16>({1, 1, 3, 2}, alpha_scales, zerop, 3, alpha_data);
Tensor output_tensor = makeOutputTensor(DataType::S16, 0.1, 0);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
EXPECT_ANY_THROW(kernel.configure());
}
-TEST(PreluTest, SInt16_CWQ_wrong_dim_quant_NEG)
+TEST(PReluTest, SInt16_CWQ_wrong_dim_quant_NEG)
{
std::vector<float> input_data(6); // data is not important
std::vector<float> alpha_data(6);
makeInputTensor<DataType::S16>({1, 1, 1, 2}, alpha_scales, zerop, 1, alpha_data);
Tensor output_tensor = makeOutputTensor(DataType::S16, 0.1, 0);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
EXPECT_ANY_THROW(kernel.configure());
}
-TEST(PreluTest, SInt16_CWQ_uneven_shape1)
+TEST(PReluTest, SInt16_CWQ_uneven_shape1)
{
std::vector<float> input_data{-0.8f, 0.2f, 0.9f, -0.7f, 0.1f, -0.4f};
std::vector<float> alpha_data{0.5f, 0.25f};
makeInputTensor<DataType::S16>({1, 1, 2}, alpha_scales, zerop, 2, alpha_data);
Tensor output_tensor = makeOutputTensor(DataType::S16, 0.025, 0);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
kernel.configure();
kernel.execute();
EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(ref_output_data));
}
-TEST(PreluTest, SInt16_CWQ_uneven_shape2)
+TEST(PReluTest, SInt16_CWQ_uneven_shape2)
{
std::vector<float> input_data{
0.0f, 0.0f, 0.0f, // Row 1, Column 1
makeInputTensor<DataType::S16>({1, 1, 1, 3}, alpha_scales, zerop, 3, alpha_data);
Tensor output_tensor = makeOutputTensor(DataType::S16, 0.001, 0);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
kernel.configure();
kernel.execute();
EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(ref_output_data));
}
-TEST(PreluTest, Input_Output_Type_NEG)
+TEST(PReluTest, Input_Output_Type_NEG)
{
Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1.f});
Tensor alpha_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1.f});
Tensor output_tensor = makeOutputTensor(DataType::U8);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
EXPECT_ANY_THROW(kernel.configure());
}
-TEST(PreluTest, Input_Alpha_Type_NEG)
+TEST(PReluTest, Input_Alpha_Type_NEG)
{
Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1}, {1.f});
Tensor alpha_tensor = makeInputTensor<DataType::U8>({1}, {1});
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
EXPECT_ANY_THROW(kernel.configure());
}
-TEST(PreluTest, Invalid_Input_Type_NEG)
+TEST(PReluTest, Invalid_Input_Type_NEG)
{
Tensor input_tensor = makeInputTensor<DataType::S64>({1}, {1});
Tensor alpha_tensor = makeInputTensor<DataType::S64>({1}, {1});
Tensor output_tensor = makeOutputTensor(DataType::S64);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
kernel.configure();
EXPECT_ANY_THROW(kernel.execute());
}
-TEST(PreluTest, Input_Output_U8_CWQ_NEG)
+TEST(PReluTest, Input_Output_U8_CWQ_NEG)
{
std::vector<float> scales{1.f, 1.f};
std::vector<int32_t> zerop{0, 0};
Tensor alpha_tensor = makeInputTensor<DataType::U8>({2, 2}, scales, zerop, 0, dummy_data);
Tensor output_tensor = makeInputTensor<DataType::U8>({2, 2}, scales, zerop, 0, dummy_data);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
EXPECT_ANY_THROW(kernel.configure());
}
-TEST(PreluTest, Input_Output_S16_CWQ_NEG)
+TEST(PReluTest, Input_Output_S16_CWQ_NEG)
{
std::vector<float> scales{1.f, 1.f};
std::vector<int32_t> zerop{0, 0};
Tensor alpha_tensor = makeInputTensor<DataType::S16>({2, 2}, scales, zerop, 0, dummy_data);
Tensor output_tensor = makeInputTensor<DataType::S16>({2, 2}, scales, zerop, 0, dummy_data);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
EXPECT_ANY_THROW(kernel.configure());
}
-TEST(PreluTest, Mixing_U8_S16_NEG)
+TEST(PReluTest, Mixing_U8_S16_NEG)
{
std::vector<float> dummy_data(4, 0.f);
Tensor input_tensor = makeInputTensor<DataType::U8>({2, 2}, 1.f, 0, dummy_data);
Tensor alpha_tensor = makeInputTensor<DataType::S16>({2, 2}, 1.f, 0, dummy_data);
Tensor output_tensor = makeInputTensor<DataType::U8>({2, 2}, 1.f, 0, dummy_data);
- Prelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
+ PRelu kernel(&input_tensor, &alpha_tensor, &output_tensor);
EXPECT_ANY_THROW(kernel.configure());
}
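
For readers following the rename above: the kernel under test implements the standard parametric ReLU, out = x when x >= 0 and alpha * x otherwise, with alpha broadcast against the input in the *Broadcast cases. A minimal reference sketch of the element-wise rule, assuming equally shaped inputs (hypothetical helper, not part of this change):

#include <cstddef>
#include <vector>

// Reference PReLU over equally shaped tensors; broadcasting is omitted for brevity.
std::vector<float> preluReference(const std::vector<float> &x, const std::vector<float> &alpha)
{
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    out[i] = x[i] >= 0.0f ? x[i] : alpha[i] * x[i];
  return out;
}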
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/PadV2.h"
+
+#include "kernels/Utils.h"
+
+#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+PadV2::PadV2(const Tensor *input, const Tensor *paddings, const Tensor *constant_values,
+ Tensor *output)
+ : Kernel({input, paddings, constant_values}, {output})
+{
+}
+
+void PadV2::configure()
+{
+ const Shape &input_shape = input()->shape();
+ const int num_dims = input_shape.num_dims();
+
+ if (num_dims > 4)
+ throw std::runtime_error("Unsupported number of dimensions.");
+
+ assert(output()->element_type() == input()->element_type());
+ assert(paddings()->element_type() == DataType::S32);
+ assert(constant_values()->element_type() == output()->element_type());
+ // Paddings shape should be [N, 2].
+ assert(paddings()->shape().num_dims() == 2);
+ assert(paddings()->shape().dim(0) == num_dims);
+ assert(paddings()->shape().dim(1) == 2);
+ // The constant_values tensor must contain exactly one element.
+ assert(constant_values()->shape().num_elements() == 1);
+
+ Shape output_shape(num_dims);
+ const auto *paddings_data = getTensorData<int32_t>(paddings());
+ for (int i = 0; i < num_dims; ++i)
+ {
+ const int32_t padding_before = paddings_data[i * 2];
+ const int32_t padding_after = paddings_data[i * 2 + 1];
+ assert(padding_before >= 0 && padding_after >= 0);
+ output_shape.dim(i) = input_shape.dim(i) + padding_before + padding_after;
+ }
+
+ output()->resize(output_shape);
+}
+
+void PadV2::execute() const
+{
+ const int num_dims = input()->shape().num_dims();
+
+ tflite::PadParams params{};
+ params.left_padding_count = num_dims;
+ params.right_padding_count = num_dims;
+
+ const auto *paddings_data = getTensorData<int32_t>(paddings());
+ for (int i = num_dims - 1; i >= 0; --i)
+ {
+ params.left_padding[i] = paddings_data[i * 2];
+ params.right_padding[i] = paddings_data[i * 2 + 1];
+ }
+
+ switch (input()->element_type())
+ {
+ case DataType::FLOAT32:
+ {
+ const auto pad_value = getTensorData<float>(constant_values())[0];
+ tflite::reference_ops::Pad(params, getTensorShape(input()), getTensorData<float>(input()),
+ &pad_value, getTensorShape(output()),
+ getTensorData<float>(output()));
+ break;
+ }
+ case DataType::U8:
+ {
+ assert(output()->zero_point() >= std::numeric_limits<uint8_t>::min());
+ assert(output()->zero_point() <= std::numeric_limits<uint8_t>::max());
+ const auto pad_value = getTensorData<uint8_t>(constant_values())[0];
+ tflite::reference_ops::Pad(params, getTensorShape(input()), getTensorData<uint8_t>(input()),
+ &pad_value, getTensorShape(output()),
+ getTensorData<uint8_t>(output()));
+ break;
+ }
+ default:
+ throw std::runtime_error("Unsupported type.");
+ }
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
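
As a reading aid for configure() above: paddings is a [num_dims, 2] tensor of {before, after} pairs, and each output dimension is the input dimension plus both pads. A standalone sketch of that shape rule, assuming plain vectors instead of Tensor objects (illustrative only):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the output-shape computation in PadV2::configure():
// out[i] = in[i] + before[i] + after[i], with paddings stored row-major as pairs.
std::vector<int32_t> paddedShape(const std::vector<int32_t> &input_shape,
                                 const std::vector<int32_t> &paddings)
{
  assert(paddings.size() == input_shape.size() * 2);
  std::vector<int32_t> out(input_shape.size());
  for (std::size_t i = 0; i < input_shape.size(); ++i)
    out[i] = input_shape[i] + paddings[i * 2] + paddings[i * 2 + 1];
  return out;
}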
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_PAD_V2_H
+#define LUCI_INTERPRETER_KERNELS_PAD_V2_H
+
+#include "core/Kernel.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class PadV2 : public Kernel
+{
+public:
+ PadV2(const Tensor *input, const Tensor *paddings, const Tensor *constant_values, Tensor *output);
+
+ const Tensor *input() const { return _inputs[0]; }
+ const Tensor *paddings() const { return _inputs[1]; }
+ const Tensor *constant_values() const { return _inputs[2]; }
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_PAD_V2_H
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/PadV2.h"
+#include "kernels/TestUtils.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+float GetTolerance(float min, float max) { return (max - min) / 255.0; }
+
+TEST(PadV2, Uint8)
+{
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::pair<float, int32_t> quant_param = quantizationParams<uint8_t>(-1.0f, 1.0f);
+ std::vector<float> input_data{-0.8, 0.2, 0.9, 0.7, 0.1, -0.3};
+ std::vector<int32_t> paddings_data{0, 0, 0, 2, 1, 3, 0, 0};
+ std::vector<float> constant_values_data{0.5};
+ Tensor input_tensor =
+ makeInputTensor<DataType::U8>({1, 2, 3, 1}, quant_param.first, quant_param.second, input_data);
+ Tensor paddings_tensor = makeInputTensor<DataType::S32>({4, 2}, paddings_data);
+ Tensor constant_values =
+ makeInputTensor<DataType::U8>({1}, quant_param.first, quant_param.second, constant_values_data);
+ Tensor output_tensor = makeOutputTensor(DataType::U8, quant_param.first, quant_param.second);
+
+ PadV2 kernel(&input_tensor, &paddings_tensor, &constant_values, &output_tensor);
+ kernel.configure();
+ kernel.execute();
+
+ std::vector<float> ref_output_data = {
+ 0.5, -0.8, 0.2, 0.9, 0.5, 0.5, 0.5, 0.5, 0.7, 0.1, -0.3, 0.5, 0.5, 0.5, //
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}; //
+ EXPECT_THAT(dequantizeTensorData(output_tensor),
+ FloatArrayNear(ref_output_data, kQuantizedTolerance));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 4, 7, 1}));
+}
+
+TEST(PadV2, Float)
+{
+ std::vector<float> input_data{1, 2, 3, 4, 5, 6};
+ std::vector<int32_t> paddings_data{1, 0, 0, 2, 0, 3, 0, 0};
+ std::vector<float> constant_values_data{7};
+ Tensor input_tensor = makeInputTensor<DataType::FLOAT32>({1, 2, 3, 1}, input_data);
+ Tensor paddings_tensor = makeInputTensor<DataType::S32>({4, 2}, paddings_data);
+ Tensor constant_values = makeInputTensor<DataType::FLOAT32>({1}, constant_values_data);
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ PadV2 kernel(&input_tensor, &paddings_tensor, &constant_values, &output_tensor);
+ kernel.configure();
+ kernel.execute();
+
+ std::vector<float> ref_output_data{7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
+ 7, 7, 7, 7, 7, 7, 7, 7, 1, 2, 3, 7, 7, 7, 4, 5,
+ 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7};
+ std::initializer_list<int32_t> ref_output_shape{2, 4, 6, 1};
+ EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
* limitations under the License.
*/
-#include "kernels/Reverse.h"
+#include "kernels/ReverseV2.h"
#include "kernels/Utils.h"
#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
namespace kernels
{
-Reverse::Reverse(const Tensor *input, const Tensor *axes, Tensor *output)
+ReverseV2::ReverseV2(const Tensor *input, const Tensor *axes, Tensor *output)
: Kernel({input, axes}, {output})
{
}
-void Reverse::configure()
+void ReverseV2::configure()
{
assert(axes()->shape().num_dims() == 1);
assert(input()->shape().num_dims() >= axes()->shape().num_elements());
output()->resize(input()->shape());
}
-void Reverse::execute() const
+void ReverseV2::execute() const
{
int axis_value = getTensorData<int32_t>(axes())[0];
switch (output()->element_type())
namespace kernels
{
-class Reverse : public Kernel
+class ReverseV2 : public Kernel
{
public:
- Reverse(const Tensor *input, const Tensor *axes, Tensor *output);
+ ReverseV2(const Tensor *input, const Tensor *axes, Tensor *output);
const Tensor *input() const { return _inputs[0]; }
const Tensor *axes() const { return _inputs[1]; }
* limitations under the License.
*/
-#include "kernels/Reverse.h"
+#include "kernels/ReverseV2.h"
#include "kernels/TestUtils.h"
namespace luci_interpreter
using namespace testing;
-template <typename T> class ReverseTest : public ::testing::Test
+template <typename T> class ReverseV2Test : public ::testing::Test
{
};
using DataTypes = ::testing::Types<float, uint8_t>;
-TYPED_TEST_CASE(ReverseTest, DataTypes);
+TYPED_TEST_CASE(ReverseV2Test, DataTypes);
-TYPED_TEST(ReverseTest, MultiDimensions)
+TYPED_TEST(ReverseV2Test, MultiDimensions)
{
// TypeParam
std::vector<TypeParam> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
Tensor output_tensor = makeOutputTensor(getElementType<TypeParam>());
- Reverse kernel = Reverse(&input_tensor, &axis_tensor, &output_tensor);
+ ReverseV2 kernel = ReverseV2(&input_tensor, &axis_tensor, &output_tensor);
kernel.configure();
kernel.execute();
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Square.h"
+#include "kernels/Utils.h"
+
+#include <stdexcept>
+#include <cmath>
+
+namespace luci_interpreter
+{
+
+namespace kernels
+{
+
+Square::Square(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}
+
+void Square::configure()
+{
+ if (input()->element_type() != output()->element_type())
+ {
+ throw std::runtime_error("Input/output tensor data type mismatch.");
+ }
+ output()->resize(input()->shape());
+}
+
+void Square::execute() const
+{
+ switch (input()->element_type())
+ {
+ case DataType::FLOAT32:
+ evalFloat();
+ break;
+
+ default:
+ throw std::runtime_error("Unsupported type.");
+ }
+}
+
+void Square::evalFloat() const
+{
+ auto in = getTensorData<float>(input());
+ auto out = getTensorData<float>(output());
+ auto size = getTensorShape(input()).FlatSize();
+ for (auto i = in; i != in + size; ++i)
+ {
+ *out = (*i) * (*i);
+ ++out;
+ }
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_SQUARE_H
+#define LUCI_INTERPRETER_KERNELS_SQUARE_H
+
+#include "core/Kernel.h"
+#include "core/KernelParams.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class Square : public Kernel
+{
+public:
+ Square(const Tensor *input, Tensor *output);
+
+ const Tensor *input() const { return _inputs[0]; }
+ Tensor *output() const { return _outputs[0]; }
+
+ void configure() override;
+ void execute() const override;
+
+private:
+ void evalFloat() const;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_SQUARE_H
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/Square.h"
+#include "kernels/TestUtils.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+TEST(SquareTest, Float)
+{
+ Shape input_shape{3, 1, 2};
+ std::vector<float> input_data1{1.0, 0.0, -1.0, 11.0, -2.0, -1.44};
+ Tensor input_tensor = makeInputTensor<DataType::FLOAT32>(input_shape, input_data1);
+ Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);
+
+ Square kernel(&input_tensor, &output_tensor);
+ kernel.configure();
+ kernel.execute();
+
+ std::vector<float> ref_output_data{1.0, 0.0, 1.0, 121.0, 4.0, 2.0736};
+ EXPECT_THAT(extractTensorData<float>(output_tensor), FloatArrayNear(ref_output_data));
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
switch (activation)
{
case Activation::NONE:
+ case Activation::TANH:
*activation_min = std::numeric_limits<float>::lowest();
*activation_max = std::numeric_limits<float>::max();
break;
switch (activation)
{
case Activation::NONE:
+ case Activation::TANH:
*activation_min = qmin;
*activation_max = qmax;
break;
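
The two hunks above extend the existing NONE handling so that a fused TANH activation also leaves the computed range unclamped, in both the float and the quantized helper. A hedged sketch of the resulting float behaviour (free function with illustrative names, not the actual helper signature):

#include <limits>

// After this change, NONE and TANH both mean no extra clamping for float outputs.
void floatActivationRange(bool is_none_or_tanh, float *activation_min, float *activation_max)
{
  if (is_none_or_tanh)
  {
    *activation_min = std::numeric_limits<float>::lowest();
    *activation_max = std::numeric_limits<float>::max();
  }
}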
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/While.h"
+#include "kernels/Utils.h"
+
+#include <cstring>
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+namespace
+{
+
+void copy(const std::vector<const Tensor *> &src, const std::vector<Tensor *> &dst)
+{
+ for (size_t i = 0; i < src.size(); ++i)
+ {
+ LUCI_INTERPRETER_CHECK(dst[i]->element_type() == src[i]->element_type());
+ dst[i]->resize(src[i]->shape());
+
+ const int32_t num_elements = src[i]->shape().num_elements();
+ const std::size_t element_size = getDataTypeSize(src[i]->element_type());
+ std::memcpy(dst[i]->data<void>(), src[i]->data<void>(), num_elements * element_size);
+ }
+}
+
+void copy(const std::vector<Tensor *> &src, const std::vector<Tensor *> &dst)
+{
+ std::vector<const Tensor *> const_src;
+ for (const auto &t : src)
+ const_src.push_back(t);
+ copy(const_src, dst);
+}
+
+} // namespace
+
+While::While(std::vector<const Tensor *> inputs, std::vector<Tensor *> outputs,
+ RuntimeGraph *cond_graph, RuntimeGraph *body_graph)
+ : Kernel(std::move(inputs), std::move(outputs)), _cond_graph(cond_graph), _body_graph(body_graph)
+{
+}
+
+void While::configure()
+{
+ LUCI_INTERPRETER_CHECK(_body_graph->getInputTensors().size() == getInputTensors().size());
+ LUCI_INTERPRETER_CHECK(_body_graph->getOutputTensors().size() == getOutputTensors().size());
+ LUCI_INTERPRETER_CHECK(_body_graph->getOutputTensors().size() == getInputTensors().size());
+
+ LUCI_INTERPRETER_CHECK(_cond_graph->getInputTensors().size() == getInputTensors().size());
+
+ const auto &cond_outputs = _cond_graph->getOutputTensors();
+ LUCI_INTERPRETER_CHECK(cond_outputs.size() == 1);
+ LUCI_INTERPRETER_CHECK(cond_outputs[0]->element_type() == DataType::BOOL);
+}
+
+/**
+ * @note Dynamic shape such as {1, 0, 8} may fail in tensor->data()
+ */
+void While::execute() const
+{
+ const auto &cond_inputs = _cond_graph->getInputTensors();
+ const auto &cond_outputs = _cond_graph->getOutputTensors();
+
+ copy(getInputTensors(), cond_inputs);
+
+ const auto &body_inputs = _body_graph->getInputTensors();
+ const auto &body_outputs = _body_graph->getOutputTensors();
+
+ while (true)
+ {
+ _cond_graph->execute();
+
+ bool cond_value = cond_outputs[0]->data<bool>()[0];
+ if (!cond_value)
+ break;
+
+ copy(cond_inputs, body_inputs);
+
+ _body_graph->execute();
+
+ copy(body_outputs, cond_inputs);
+ }
+
+ copy(cond_inputs, getOutputTensors());
+}
+
+} // namespace kernels
+} // namespace luci_interpreter
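
To summarize the data flow in While::execute() above: the kernel inputs seed the condition graph, each iteration copies the condition inputs into the body, runs the body, and copies the body outputs back into the condition inputs, so on exit those tensors hold the final loop state. The control structure, stripped of tensor copying, is the familiar functional while loop (generic sketch, independent of RuntimeGraph):

// Generic shape of the loop above: state flows through body() while cond(state) holds.
template <typename State, typename Cond, typename Body>
State runWhile(State state, Cond cond, Body body)
{
  while (cond(state))
    state = body(state);
  return state;
}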
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_KERNELS_WHILE_H
+#define LUCI_INTERPRETER_KERNELS_WHILE_H
+
+#include "core/Kernel.h"
+#include "core/RuntimeGraph.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+
+class While : public Kernel
+{
+public:
+ While(std::vector<const Tensor *> inputs, std::vector<Tensor *> outputs, RuntimeGraph *cond_graph,
+ RuntimeGraph *body_graph);
+
+ const Tensor *input(int index) const { return _inputs[index]; }
+ Tensor *output(int index) const { return _outputs[index]; }
+
+ void configure() override;
+ void execute() const override;
+
+private:
+ RuntimeGraph *const _cond_graph = nullptr;
+ RuntimeGraph *const _body_graph = nullptr;
+};
+
+} // namespace kernels
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_KERNELS_WHILE_H
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/RuntimeModule.h"
+#include "kernels/Add.h"
+#include "kernels/Less.h"
+#include "kernels/While.h"
+#include "kernels/TestUtils.h"
+
+namespace luci_interpreter
+{
+namespace kernels
+{
+namespace
+{
+
+using namespace testing;
+
+RuntimeGraph *buildCondSubgraph(RuntimeModule *module, DataType dtype, Tensor *input_cond)
+{
+ RuntimeGraph *graph = module->addGraph();
+ Tensor *input =
+ graph->addTensor(std::make_unique<Tensor>(dtype, Shape{}, AffineQuantization{}, ""));
+ Tensor *output =
+ graph->addTensor(std::make_unique<Tensor>(DataType::BOOL, Shape{}, AffineQuantization{}, ""));
+
+ graph->setInputTensors({input});
+ graph->setOutputTensors({output});
+
+ graph->addKernel(std::make_unique<Less>(input, input_cond, output));
+
+ return graph;
+}
+
+RuntimeGraph *buildBodySubgraph(RuntimeModule *module, DataType dtype, Tensor *input_add)
+{
+ RuntimeGraph *graph = module->addGraph();
+ Tensor *input =
+ graph->addTensor(std::make_unique<Tensor>(dtype, Shape{}, AffineQuantization{}, ""));
+ Tensor *output =
+ graph->addTensor(std::make_unique<Tensor>(dtype, Shape{}, AffineQuantization{}, ""));
+
+ graph->setInputTensors({input});
+ graph->setOutputTensors({output});
+
+ AddParams params{};
+ params.activation = Activation::NONE;
+ graph->addKernel(std::make_unique<Add>(input, input_add, output, params));
+
+ return graph;
+}
+
+TEST(WhileTest, FloatLoop10)
+{
+ Tensor input = makeInputTensor<DataType::FLOAT32>({1}, {1});
+ Tensor output = makeOutputTensor(DataType::FLOAT32);
+
+ Tensor input_cond = makeInputTensor<DataType::FLOAT32>({1}, {10});
+ Tensor input_add = makeInputTensor<DataType::FLOAT32>({1}, {1});
+
+ RuntimeModule module(nullptr);
+ RuntimeGraph *cond_graph = buildCondSubgraph(&module, DataType::FLOAT32, &input_cond);
+ RuntimeGraph *body_graph = buildBodySubgraph(&module, DataType::FLOAT32, &input_add);
+
+ While kernel({&input}, {&output}, cond_graph, body_graph);
+ kernel.configure();
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorData<float>(output), FloatArrayNear({10}));
+}
+
+} // namespace
+} // namespace kernels
+} // namespace luci_interpreter
set(SOURCES
GraphLoader.h
GraphLoader.cpp
+ KernelBuilderHelper.h
+ KernelBuilderHelper.cpp
KernelBuilder.h
KernelBuilder.cpp
ModuleLoader.h
return getNodeDataImpl<DataType::S32>(node, data_size);
case DataType::S64:
return getNodeDataImpl<DataType::S64>(node, data_size);
+ case DataType::BOOL:
+ return getNodeDataImpl<DataType::BOOL>(node, data_size);
default:
throw std::runtime_error("Unsupported type.");
}
case luci::CircleOpcode::CIRCLEIFOUT:
case luci::CircleOpcode::CIRCLESPLITOUT:
case luci::CircleOpcode::CIRCLEUNPACKOUT:
+ case luci::CircleOpcode::CIRCLEWHILEOUT:
return false;
default:
return true;
if (isExecutableNode(node))
{
- std::unique_ptr<Kernel> kernel = node->accept(&kernel_builder);
+ std::unique_ptr<Kernel> kernel = kernel_builder.build(node);
_runtime_to_ir.kernel_to_node.emplace(kernel.get(), node);
_runtime_graph->addKernel(std::move(kernel));
}
#include "kernels/ArgMax.h"
#include "kernels/AveragePool2D.h"
#include "kernels/BatchToSpaceND.h"
+#include "kernels/Cast.h"
#include "kernels/Concatenation.h"
#include "kernels/Conv2D.h"
#include "kernels/DepthToSpace.h"
#include "kernels/MaxPool2D.h"
#include "kernels/Mean.h"
#include "kernels/Minimum.h"
+#include "kernels/MirrorPad.h"
#include "kernels/Mul.h"
#include "kernels/Neg.h"
#include "kernels/NotEqual.h"
#include "kernels/Pack.h"
#include "kernels/Pad.h"
+#include "kernels/PadV2.h"
#include "kernels/Pow.h"
-#include "kernels/Prelu.h"
+#include "kernels/PRelu.h"
#include "kernels/Relu.h"
#include "kernels/Relu6.h"
#include "kernels/Reshape.h"
#include "kernels/ResizeBilinear.h"
#include "kernels/ResizeNearestNeighbor.h"
-#include "kernels/Reverse.h"
+#include "kernels/ReverseV2.h"
#include "kernels/Rsqrt.h"
#include "kernels/Slice.h"
#include "kernels/Softmax.h"
#include "kernels/Split.h"
#include "kernels/StridedSlice.h"
#include "kernels/Sqrt.h"
+#include "kernels/Square.h"
#include "kernels/SquaredDifference.h"
#include "kernels/Squeeze.h"
#include "kernels/Sub.h"
#include "kernels/Unpack.h"
#include "kernels/Transpose.h"
#include "kernels/TransposeConv.h"
+#include "kernels/While.h"
#include <stdexcept>
-namespace luci_interpreter
+namespace
{
template <typename CircleNodeOut>
-static std::vector<const loco::Node *> collectOutputNodes(const luci::CircleNode *node)
+std::vector<const loco::Node *> collectOutputNodes(const luci::CircleNode *node)
{
std::vector<const CircleNodeOut *> output_nodes;
for (const loco::Node *loco_node : loco::succs(node))
return {output_nodes.cbegin(), output_nodes.cend()};
}
-const Tensor *KernelBuilder::getInputTensor(const loco::Node *node) const
+} // namespace
+
+namespace luci_interpreter
{
- const Tensor *tensor = _node_to_tensor.at(node);
- assert(tensor != nullptr);
- return tensor;
-}
-const Tensor *KernelBuilder::getOptionalInputTensor(const loco::Node *node) const
+// TODO move to anonymous namespace
+enum class KB
{
- if (dynamic_cast<const luci::CircleOutputExclude *>(node))
+ ABC,
+ DEF,
+ GHIJ,
+ KLMN,
+ OPQR,
+ STUV,
+ WXYZ,
+};
+
+#define DECLARE_VISIT(CLASS) std::unique_ptr<Kernel> visit(const luci::CLASS *) override
+
+template <KB kb> class KernelBuilderLet;
+
+template <>
+class KernelBuilderLet<KB::ABC> : public luci::CircleNodeVisitor<std::unique_ptr<Kernel>>,
+ public KernelBuilderHelper
+{
+public:
+ KernelBuilderLet(
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
+ const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
+ : KernelBuilderHelper(graph_to_runtime_graph, node_to_tensor)
{
- return nullptr;
}
- return getInputTensor(node);
-}
-Tensor *KernelBuilder::getOutputTensor(const loco::Node *node) const
-{
- Tensor *tensor = _node_to_tensor.at(node);
- assert(tensor != nullptr);
- return tensor;
-}
+public:
+ std::unique_ptr<Kernel> visit(const luci::CircleNode *) { return nullptr; }
+
+public:
+ DECLARE_VISIT(CircleAdd);
+ DECLARE_VISIT(CircleArgMax);
+ DECLARE_VISIT(CircleAveragePool2D);
+ DECLARE_VISIT(CircleBatchToSpaceND);
+ DECLARE_VISIT(CircleCast);
+ DECLARE_VISIT(CircleConcatenation);
+ DECLARE_VISIT(CircleConst);
+ DECLARE_VISIT(CircleConv2D);
+};
+
+template <>
+class KernelBuilderLet<KB::DEF> : public luci::CircleNodeVisitor<std::unique_ptr<Kernel>>,
+ public KernelBuilderHelper
+{
+public:
+ KernelBuilderLet(
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
+ const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
+ : KernelBuilderHelper(graph_to_runtime_graph, node_to_tensor)
+ {
+ }
-std::vector<Tensor *>
-KernelBuilder::getOutputTensors(const std::vector<const loco::Node *> &nodes) const
-{
- std::vector<Tensor *> tensors;
- tensors.reserve(nodes.size());
- for (const loco::Node *node : nodes)
- tensors.push_back(getOutputTensor(node));
- return tensors;
-}
+public:
+ std::unique_ptr<Kernel> visit(const luci::CircleNode *) { return nullptr; }
+
+public:
+ DECLARE_VISIT(CircleDepthToSpace);
+ DECLARE_VISIT(CircleDepthwiseConv2D);
+ DECLARE_VISIT(CircleDiv);
+ DECLARE_VISIT(CircleElu);
+ DECLARE_VISIT(CircleEqual);
+ DECLARE_VISIT(CircleExp);
+ DECLARE_VISIT(CircleFloor);
+ DECLARE_VISIT(CircleFloorDiv);
+ DECLARE_VISIT(CircleFullyConnected);
+};
+
+template <>
+class KernelBuilderLet<KB::GHIJ> : public luci::CircleNodeVisitor<std::unique_ptr<Kernel>>,
+ public KernelBuilderHelper
+{
+public:
+ KernelBuilderLet(
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
+ const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
+ : KernelBuilderHelper(graph_to_runtime_graph, node_to_tensor)
+ {
+ }
-RuntimeGraph *KernelBuilder::getRuntimeGraph(const loco::Graph *graph) const
-{
- RuntimeGraph *runtime_graph = _graph_to_runtime_graph.at(graph);
- assert(runtime_graph != nullptr);
- return runtime_graph;
-}
+public:
+ std::unique_ptr<Kernel> visit(const luci::CircleNode *) { return nullptr; }
+
+public:
+ DECLARE_VISIT(CircleGreater);
+ DECLARE_VISIT(CircleGreaterEqual);
+ DECLARE_VISIT(CircleIf);
+ DECLARE_VISIT(CircleInput);
+ DECLARE_VISIT(CircleInstanceNorm);
+};
+
+template <>
+class KernelBuilderLet<KB::KLMN> : public luci::CircleNodeVisitor<std::unique_ptr<Kernel>>,
+ public KernelBuilderHelper
+{
+public:
+ KernelBuilderLet(
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
+ const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
+ : KernelBuilderHelper(graph_to_runtime_graph, node_to_tensor)
+ {
+ }
+
+public:
+ std::unique_ptr<Kernel> visit(const luci::CircleNode *) { return nullptr; }
+
+public:
+ DECLARE_VISIT(CircleL2Normalize);
+ DECLARE_VISIT(CircleL2Pool2D);
+ DECLARE_VISIT(CircleLeakyRelu);
+ DECLARE_VISIT(CircleLess);
+ DECLARE_VISIT(CircleLessEqual);
+ DECLARE_VISIT(CircleLocalResponseNormalization);
+ DECLARE_VISIT(CircleLogSoftmax);
+ DECLARE_VISIT(CircleLogicalAnd);
+ DECLARE_VISIT(CircleLogicalNot);
+ DECLARE_VISIT(CircleLogicalOr);
+ DECLARE_VISIT(CircleLogistic);
+ DECLARE_VISIT(CircleMaxPool2D);
+ DECLARE_VISIT(CircleMaximum);
+ DECLARE_VISIT(CircleMean);
+ DECLARE_VISIT(CircleMinimum);
+ DECLARE_VISIT(CircleMirrorPad);
+ DECLARE_VISIT(CircleMul);
+ DECLARE_VISIT(CircleNeg);
+ DECLARE_VISIT(CircleNotEqual);
+};
+
+template <>
+class KernelBuilderLet<KB::OPQR> : public luci::CircleNodeVisitor<std::unique_ptr<Kernel>>,
+ public KernelBuilderHelper
+{
+public:
+ KernelBuilderLet(
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
+ const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
+ : KernelBuilderHelper(graph_to_runtime_graph, node_to_tensor)
+ {
+ }
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleNode *)
+public:
+ std::unique_ptr<Kernel> visit(const luci::CircleNode *) { return nullptr; }
+
+public:
+ DECLARE_VISIT(CircleOutput);
+ DECLARE_VISIT(CirclePRelu);
+ DECLARE_VISIT(CirclePack);
+ DECLARE_VISIT(CirclePad);
+ DECLARE_VISIT(CirclePadV2);
+ DECLARE_VISIT(CirclePow);
+ DECLARE_VISIT(CircleRelu);
+ DECLARE_VISIT(CircleRelu6);
+ DECLARE_VISIT(CircleReshape);
+ DECLARE_VISIT(CircleResizeBilinear);
+ DECLARE_VISIT(CircleResizeNearestNeighbor);
+ DECLARE_VISIT(CircleReverseV2);
+ DECLARE_VISIT(CircleRsqrt);
+};
+
+template <>
+class KernelBuilderLet<KB::STUV> : public luci::CircleNodeVisitor<std::unique_ptr<Kernel>>,
+ public KernelBuilderHelper
+{
+public:
+ KernelBuilderLet(
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
+ const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
+ : KernelBuilderHelper(graph_to_runtime_graph, node_to_tensor)
+ {
+ }
+
+public:
+ std::unique_ptr<Kernel> visit(const luci::CircleNode *) { return nullptr; }
+
+public:
+ DECLARE_VISIT(CircleSlice);
+ DECLARE_VISIT(CircleSoftmax);
+ DECLARE_VISIT(CircleSpaceToBatchND);
+ DECLARE_VISIT(CircleSpaceToDepth);
+ DECLARE_VISIT(CircleSplit);
+ DECLARE_VISIT(CircleSqrt);
+ DECLARE_VISIT(CircleSquare);
+ DECLARE_VISIT(CircleSquaredDifference);
+ DECLARE_VISIT(CircleSqueeze);
+ DECLARE_VISIT(CircleStridedSlice);
+ DECLARE_VISIT(CircleSub);
+ DECLARE_VISIT(CircleTanh);
+ DECLARE_VISIT(CircleTranspose);
+ DECLARE_VISIT(CircleTransposeConv);
+ DECLARE_VISIT(CircleUnpack);
+};
+
+template <>
+class KernelBuilderLet<KB::WXYZ> : public luci::CircleNodeVisitor<std::unique_ptr<Kernel>>,
+ public KernelBuilderHelper
+{
+public:
+ KernelBuilderLet(
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
+ const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
+ : KernelBuilderHelper(graph_to_runtime_graph, node_to_tensor)
+ {
+ }
+
+public:
+ std::unique_ptr<Kernel> visit(const luci::CircleNode *) { return nullptr; }
+
+public:
+ DECLARE_VISIT(CircleWhile);
+};
+
+#undef DECLARE_VISIT
+
+std::unique_ptr<Kernel> KernelBuilder::build(const luci::CircleNode *node)
{
- throw std::invalid_argument("Unsupported operator.");
+#define VISIT_KB(GRP) \
+ do \
+ { \
+ KernelBuilderLet<KB::GRP> kbl(graph_to_runtime_graph(), node_to_tensor()); \
+ auto ret = node->accept(&kbl); \
+ if (ret != nullptr) \
+ return ret; \
+ } while (false)
+
+ VISIT_KB(ABC);
+ VISIT_KB(DEF);
+ VISIT_KB(GHIJ);
+ VISIT_KB(KLMN);
+ VISIT_KB(OPQR);
+ VISIT_KB(STUV);
+ VISIT_KB(WXYZ);
+
+#undef VISIT_KB
+ std::string msg = "Unsupported operator: ";
+ msg += std::to_string(static_cast<uint32_t>(node->opcode())) + " " + std::string(node->name());
+ throw std::invalid_argument(msg.c_str());
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleAdd *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::ABC>::visit(const luci::CircleAdd *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Add>(input1, input2, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleArgMax *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::ABC>::visit(const luci::CircleArgMax *node)
{
assert(node->arity() == 2);
const Tensor *input = getInputTensor(node->input());
return std::make_unique<kernels::ArgMax>(input, axis, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleAveragePool2D *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::ABC>::visit(const luci::CircleAveragePool2D *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::AveragePool2D>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleBatchToSpaceND *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::ABC>::visit(const luci::CircleBatchToSpaceND *node)
{
assert(node->arity() == 3);
return std::make_unique<kernels::BatchToSpaceND>(input, block_shape, crops, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleConcatenation *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::ABC>::visit(const luci::CircleCast *node)
+{
+ assert(node->arity() == 1);
+
+ const Tensor *input = getInputTensor(node->x());
+ Tensor *output = getOutputTensor(node);
+
+ return std::make_unique<kernels::Cast>(input, output);
+}
+
+std::unique_ptr<Kernel> KernelBuilderLet<KB::ABC>::visit(const luci::CircleConcatenation *node)
{
std::vector<const Tensor *> inputs(node->numValues());
for (uint32_t i = 0; i < node->numValues(); ++i)
return std::make_unique<kernels::Concatenation>(std::move(inputs), output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleConst *)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::ABC>::visit(const luci::CircleConst *)
{
throw std::runtime_error("Const node cannot be executed.");
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleConv2D *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::ABC>::visit(const luci::CircleConv2D *node)
{
assert(node->arity() == 3);
return std::make_unique<kernels::Conv2D>(input, filter, bias, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleDepthToSpace *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::DEF>::visit(const luci::CircleDepthToSpace *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::DepthToSpace>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleDepthwiseConv2D *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::DEF>::visit(const luci::CircleDepthwiseConv2D *node)
{
assert(node->arity() == 3);
return std::make_unique<kernels::DepthwiseConv2D>(input, filter, bias, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleDiv *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::DEF>::visit(const luci::CircleDiv *node)
{
assert(node->arity() == 2);
const Tensor *input1 = getInputTensor(node->x());
return std::make_unique<kernels::Div>(input1, input2, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleElu *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::DEF>::visit(const luci::CircleElu *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Elu>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleExp *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::DEF>::visit(const luci::CircleEqual *node)
{
- assert(node->arity() == 1);
+ assert(node->arity() == 2);
- const Tensor *input = getInputTensor(node->x());
+ const Tensor *x = getInputTensor(node->x());
+ const Tensor *y = getInputTensor(node->y());
Tensor *output = getOutputTensor(node);
- return std::make_unique<kernels::Exp>(input, output);
+ return std::make_unique<kernels::Equal>(x, y, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleFloor *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::DEF>::visit(const luci::CircleExp *node)
{
assert(node->arity() == 1);
const Tensor *input = getInputTensor(node->x());
Tensor *output = getOutputTensor(node);
- return std::make_unique<kernels::Floor>(input, output);
+ return std::make_unique<kernels::Exp>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleFloorDiv *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::DEF>::visit(const luci::CircleFloor *node)
{
- assert(node->arity() == 2);
+ assert(node->arity() == 1);
- const Tensor *x = getInputTensor(node->x());
- const Tensor *y = getInputTensor(node->y());
+ const Tensor *input = getInputTensor(node->x());
Tensor *output = getOutputTensor(node);
- return std::make_unique<kernels::FloorDiv>(x, y, output);
+ return std::make_unique<kernels::Floor>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleEqual *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::DEF>::visit(const luci::CircleFloorDiv *node)
{
assert(node->arity() == 2);
const Tensor *y = getInputTensor(node->y());
Tensor *output = getOutputTensor(node);
- return std::make_unique<kernels::Equal>(x, y, output);
+ return std::make_unique<kernels::FloorDiv>(x, y, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleFullyConnected *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::DEF>::visit(const luci::CircleFullyConnected *node)
{
assert(node->arity() == 3);
return std::make_unique<kernels::FullyConnected>(input, weights, bias, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleGreater *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::GHIJ>::visit(const luci::CircleGreater *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Greater>(x, y, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleGreaterEqual *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::GHIJ>::visit(const luci::CircleGreaterEqual *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::GreaterEqual>(x, y, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleIf *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::GHIJ>::visit(const luci::CircleIf *node)
{
auto output_nodes = collectOutputNodes<luci::CircleIfOut>(node);
assert(node->arity() == 1 + node->input_count());
else_graph);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleInstanceNorm *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::GHIJ>::visit(const luci::CircleInstanceNorm *node)
{
assert(node->arity() == 3);
return std::make_unique<kernels::InstanceNorm>(input, gamma, beta, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleInput *)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::GHIJ>::visit(const luci::CircleInput *)
{
throw std::runtime_error("Input node cannot be executed.");
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleL2Normalize *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleL2Normalize *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::L2Normalize>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleL2Pool2D *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleL2Pool2D *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::L2Pool2D>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleLeakyRelu *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleLeakyRelu *node)
{
assert(node->arity() == 1);
const Tensor *input = getInputTensor(node->features());
return std::make_unique<kernels::LeakyRelu>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleLess *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleLess *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Less>(x, y, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleLessEqual *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleLessEqual *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::LessEqual>(x, y, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleLocalResponseNormalization *node)
+std::unique_ptr<Kernel>
+KernelBuilderLet<KB::KLMN>::visit(const luci::CircleLocalResponseNormalization *node)
{
assert(node->arity() == 1);
const Tensor *input = getInputTensor(node->input());
return std::make_unique<kernels::LocalResponseNormalization>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleLogicalAnd *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleLogicalAnd *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::LogicalAnd>(input1, input2, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleLogicalNot *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleLogicalNot *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::LogicalNot>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleLogicalOr *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleLogicalOr *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::LogicalOr>(input1, input2, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleLogistic *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleLogistic *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Logistic>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleLogSoftmax *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleLogSoftmax *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::LogSoftmax>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleMaximum *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleMaximum *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Maximum>(input1, input2, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleMaxPool2D *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleMaxPool2D *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::MaxPool2D>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleMean *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleMean *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Mean>(input, axes, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleMinimum *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleMinimum *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Minimum>(input1, input2, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleMul *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleMirrorPad *node)
+{
+ assert(node->arity() == 2);
+
+ const Tensor *input = getInputTensor(node->input());
+ const Tensor *paddings = getInputTensor(node->paddings());
+ Tensor *output = getOutputTensor(node);
+
+ MirrorPadParams params{};
+ params.mode = node->mode();
+
+ return std::make_unique<kernels::MirrorPad>(input, paddings, output, params);
+}
+
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleMul *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Mul>(input1, input2, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleNeg *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleNeg *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Neg>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleNotEqual *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::KLMN>::visit(const luci::CircleNotEqual *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::NotEqual>(x, y, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleOutput *)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CircleOutput *)
{
throw std::runtime_error("Output node cannot be executed.");
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CirclePack *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CirclePack *node)
{
assert(node->arity() == node->values_count());
return std::make_unique<kernels::Pack>(std::move(inputs), output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CirclePad *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CirclePad *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Pad>(input, paddings, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CirclePow *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CirclePadV2 *node)
+{
+ assert(node->arity() == 3);
+
+ const Tensor *input = getInputTensor(node->input());
+ const Tensor *paddings = getInputTensor(node->paddings());
+ const Tensor *constant_values = getInputTensor(node->constant_values());
+ Tensor *output = getOutputTensor(node);
+
+ return std::make_unique<kernels::PadV2>(input, paddings, constant_values, output);
+}
+
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CirclePow *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Pow>(input1, input2, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CirclePRelu *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CirclePRelu *node)
{
assert(node->arity() == 2);
const Tensor *alpha = getInputTensor(node->alpha());
Tensor *output = getOutputTensor(node);
- return std::make_unique<kernels::Prelu>(input, alpha, output);
+ return std::make_unique<kernels::PRelu>(input, alpha, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleRelu *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CircleRelu *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Relu>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleRelu6 *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CircleRelu6 *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Relu6>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleReshape *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CircleReshape *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Reshape>(input, shape, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleResizeBilinear *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CircleResizeBilinear *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::ResizeBilinear>(input, size, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleResizeNearestNeighbor *node)
+std::unique_ptr<Kernel>
+KernelBuilderLet<KB::OPQR>::visit(const luci::CircleResizeNearestNeighbor *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::ResizeNearestNeighbor>(input, size, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleReverseV2 *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CircleReverseV2 *node)
{
assert(node->arity() == 2);
const Tensor *axes = getInputTensor(node->axis());
Tensor *output = getOutputTensor(node);
- return std::make_unique<kernels::Reverse>(input, axes, output);
+ return std::make_unique<kernels::ReverseV2>(input, axes, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleRsqrt *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CircleRsqrt *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Rsqrt>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleSlice *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSlice *node)
{
assert(node->arity() == 3);
return std::make_unique<kernels::Slice>(input, begin, size, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleSoftmax *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSoftmax *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Softmax>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleSpaceToBatchND *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSpaceToBatchND *node)
{
assert(node->arity() == 3);
;
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleSpaceToDepth *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSpaceToDepth *node)
{
assert(node->arity() == 1);
const Tensor *input = getInputTensor(node->input());
return std::make_unique<kernels::SpaceToDepth>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleSplit *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSplit *node)
{
auto output_nodes = collectOutputNodes<luci::CircleSplitOut>(node);
assert(node->arity() == 2);
return std::make_unique<kernels::Split>(axis, input, std::move(outputs));
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleSqrt *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSqrt *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Sqrt>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleSquaredDifference *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSquare *node)
+{
+ assert(node->arity() == 1);
+
+ const Tensor *input = getInputTensor(node->x());
+ Tensor *output = getOutputTensor(node);
+
+ return std::make_unique<kernels::Square>(input, output);
+}
+
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSquaredDifference *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::SquaredDifference>(input1, input2, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleSqueeze *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSqueeze *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Squeeze>(input, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleStridedSlice *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleStridedSlice *node)
{
assert(node->arity() == 4);
return std::make_unique<kernels::StridedSlice>(input, begin, end, strides, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleSub *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleSub *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Sub>(input1, input2, output, params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleTanh *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleTanh *node)
{
assert(node->arity() == 1);
return std::make_unique<kernels::Tanh>(input, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleTranspose *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleTranspose *node)
{
assert(node->arity() == 2);
return std::make_unique<kernels::Transpose>(input, perm, output);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleTransposeConv *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleTransposeConv *node)
{
assert(node->arity() == 4);
params);
}
-std::unique_ptr<Kernel> KernelBuilder::visit(const luci::CircleUnpack *node)
+std::unique_ptr<Kernel> KernelBuilderLet<KB::STUV>::visit(const luci::CircleUnpack *node)
{
auto output_nodes = collectOutputNodes<luci::CircleUnpackOut>(node);
assert(node->arity() == 1);
return std::make_unique<kernels::Unpack>(input, std::move(outputs), params);
}
+std::unique_ptr<Kernel> KernelBuilderLet<KB::WXYZ>::visit(const luci::CircleWhile *node)
+{
+ auto output_nodes = collectOutputNodes<luci::CircleWhileOut>(node);
+ assert(node->arity() == node->input_count());
+ assert(output_nodes.size() == static_cast<size_t>(node->output_count()));
+
+ std::vector<const Tensor *> inputs(node->input_count());
+ for (uint32_t i = 0; i < node->input_count(); ++i)
+ {
+ inputs[i] = getInputTensor(node->input(i));
+ }
+ std::vector<Tensor *> outputs = getOutputTensors(output_nodes);
+
+ RuntimeGraph *cond_graph = getRuntimeGraph(node->cond_graph());
+ RuntimeGraph *body_graph = getRuntimeGraph(node->body_graph());
+
+ return std::make_unique<kernels::While>(std::move(inputs), std::move(outputs), cond_graph,
+ body_graph);
+}
+
} // namespace luci_interpreter
#ifndef LUCI_INTERPRETER_LOADER_KERNELBUILDER_H
#define LUCI_INTERPRETER_LOADER_KERNELBUILDER_H
+#include "loader/KernelBuilderHelper.h"
+
#include "core/Kernel.h"
#include "core/RuntimeGraph.h"
#include <luci/IR/CircleNodeVisitor.h>
#include <memory>
-#include <vector>
#include <unordered_map>
namespace luci_interpreter
{
-class KernelBuilder : public luci::CircleNodeVisitor<std::unique_ptr<Kernel>>
+class KernelBuilder : public KernelBuilderHelper
{
public:
KernelBuilder(
const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
- : _graph_to_runtime_graph(graph_to_runtime_graph), _node_to_tensor(node_to_tensor)
+ : KernelBuilderHelper(graph_to_runtime_graph, node_to_tensor)
{
}
- std::unique_ptr<Kernel> visit(const luci::CircleNode *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleAdd *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleArgMax *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleAveragePool2D *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleBatchToSpaceND *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleConcatenation *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleConv2D *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleConst *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleDepthToSpace *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleDepthwiseConv2D *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleDiv *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleElu *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleExp *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleFloor *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleFloorDiv *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleEqual *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleFullyConnected *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleGreater *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleGreaterEqual *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleIf *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleInstanceNorm *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleL2Normalize *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleL2Pool2D *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleLeakyRelu *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleLess *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleLessEqual *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleLocalResponseNormalization *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleLogicalAnd *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleLogicalNot *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleLogicalOr *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleLogistic *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleLogSoftmax *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleInput *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleMaximum *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleMaxPool2D *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleMean *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleMinimum *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleMul *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleNeg *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleNotEqual *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleOutput *node) override;
- std::unique_ptr<Kernel> visit(const luci::CirclePack *node) override;
- std::unique_ptr<Kernel> visit(const luci::CirclePad *node) override;
- std::unique_ptr<Kernel> visit(const luci::CirclePow *node) override;
- std::unique_ptr<Kernel> visit(const luci::CirclePRelu *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleRelu *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleRelu6 *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleReshape *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleResizeBilinear *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleResizeNearestNeighbor *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleReverseV2 *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleRsqrt *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleSlice *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleSoftmax *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleSpaceToBatchND *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleSpaceToDepth *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleSplit *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleStridedSlice *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleSqrt *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleSquaredDifference *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleSqueeze *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleSub *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleTanh *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleTranspose *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleTransposeConv *node) override;
- std::unique_ptr<Kernel> visit(const luci::CircleUnpack *node) override;
-
-private:
- const Tensor *getInputTensor(const loco::Node *node) const;
-
- const Tensor *getOptionalInputTensor(const loco::Node *node) const;
-
- Tensor *getOutputTensor(const loco::Node *node) const;
-
- std::vector<Tensor *> getOutputTensors(const std::vector<const loco::Node *> &nodes) const;
-
- RuntimeGraph *getRuntimeGraph(const loco::Graph *graph) const;
-
-private:
- const std::unordered_map<const loco::Graph *, RuntimeGraph *> &_graph_to_runtime_graph;
- const std::unordered_map<const loco::Node *, Tensor *> &_node_to_tensor;
+ std::unique_ptr<Kernel> build(const luci::CircleNode *node);
};
} // namespace luci_interpreter
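// --- Editor's sketch (not part of the patch) ------------------------------------
// The KernelBuilder::build() entry point declared above replaces the old
// accept()-based dispatch. One plausible implementation, mirroring the OE-grouped
// OperationExporter::export_node() later in this change, is to try each
// KernelBuilderLet group in turn. Assumptions: the KB enum groups ops
// alphabetically like OE; each KernelBuilderLet<KB::X> is a
// luci::CircleNodeVisitor<std::unique_ptr<Kernel>> constructed from the same two
// maps as KernelBuilderHelper; and its fallback visit(const luci::CircleNode *)
// returns nullptr for ops outside its group.
std::unique_ptr<Kernel> KernelBuilder::build(const luci::CircleNode *node)
{
#define VISIT_KB(GRP)                                                          \
  do                                                                           \
  {                                                                            \
    KernelBuilderLet<KB::GRP> kb(graph_to_runtime_graph(), node_to_tensor());  \
    if (auto kernel = node->accept(&kb))                                       \
      return kernel;                                                           \
  } while (false)

  VISIT_KB(ABC);
  VISIT_KB(DEF);
  VISIT_KB(GHIJ);
  VISIT_KB(KLMN);
  VISIT_KB(OPQR);
  VISIT_KB(STUV);
  VISIT_KB(WXYZ);
#undef VISIT_KB

  throw std::runtime_error(std::string("Unsupported operator: ") + node->name());
}
// ---------------------------------------------------------------------------------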
#include <kernels/Add.h>
#include <kernels/ArgMax.h>
#include <kernels/AveragePool2D.h>
+#include <kernels/Cast.h>
#include <kernels/Concatenation.h>
#include <kernels/Conv2D.h>
#include <kernels/DepthToSpace.h>
#include <kernels/Neg.h>
#include <kernels/NotEqual.h>
#include <kernels/Pad.h>
+#include <kernels/PadV2.h>
#include <kernels/Pow.h>
-#include <kernels/Prelu.h>
+#include <kernels/PRelu.h>
#include <kernels/Relu.h>
#include <kernels/Relu6.h>
#include <kernels/Reshape.h>
#include <kernels/ResizeBilinear.h>
#include <kernels/ResizeNearestNeighbor.h>
-#include <kernels/Reverse.h>
+#include <kernels/ReverseV2.h>
#include <kernels/Rsqrt.h>
#include <kernels/Slice.h>
#include <kernels/Softmax.h>
KernelBuilder kernel_builder(graph_to_runtime_graph, _node_to_tensor);
- auto kernel = op->accept(&kernel_builder);
+ auto kernel = kernel_builder.build(op);
return std::unique_ptr<KernelT>(dynamic_cast<KernelT *>(kernel.release()));
}
EXPECT_THAT(kernel->params().activation, Eq(op->fusedActivationFunction()));
}
+TEST_F(KernelBuilderTest, Cast)
+{
+ auto *input = createInputNode();
+
+ auto *op = createNode<luci::CircleCast>();
+ op->x(input);
+
+ auto kernel = buildKernel<kernels::Cast>(op);
+ ASSERT_THAT(kernel, NotNull());
+
+ checkTensor(kernel->input(), input);
+ checkTensor(kernel->output(), op);
+}
+
TEST_F(KernelBuilderTest, Concatenation)
{
auto *input1 = createInputNode();
checkTensor(kernel->output(), op);
}
+TEST_F(KernelBuilderTest, PadV2)
+{
+ auto *input = createInputNode();
+ auto *paddings = createInputNode();
+ auto *constant_values = createInputNode();
+
+ auto *op = createNode<luci::CirclePadV2>();
+ op->input(input);
+ op->paddings(paddings);
+ op->constant_values(constant_values);
+
+ auto kernel = buildKernel<kernels::PadV2>(op);
+ ASSERT_THAT(kernel, NotNull());
+
+ checkTensor(kernel->input(), input);
+ checkTensor(kernel->paddings(), paddings);
+ checkTensor(kernel->constant_values(), constant_values);
+ checkTensor(kernel->output(), op);
+}
+
TEST_F(KernelBuilderTest, Pow)
{
auto *input1 = createInputNode();
checkTensor(kernel->output(), op);
}
-TEST_F(KernelBuilderTest, Prelu)
+TEST_F(KernelBuilderTest, PRelu)
{
auto *input = createInputNode();
auto *alpha = createInputNode();
op->input(input);
op->alpha(alpha);
- auto kernel = buildKernel<kernels::Prelu>(op);
+ auto kernel = buildKernel<kernels::PRelu>(op);
ASSERT_THAT(kernel, NotNull());
checkTensor(kernel->input(), input);
op->tensor(input);
op->axis(axes);
- auto kernel = buildKernel<kernels::Reverse>(op);
+ auto kernel = buildKernel<kernels::ReverseV2>(op);
ASSERT_THAT(kernel, NotNull());
checkTensor(kernel->input(), input);
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loader/KernelBuilderHelper.h"
+
+#include <luci/IR/Nodes/CircleOutput.h>
+
+namespace luci_interpreter
+{
+
+const Tensor *KernelBuilderHelper::getInputTensor(const loco::Node *node) const
+{
+ const Tensor *tensor = _node_to_tensor.at(node);
+ assert(tensor != nullptr);
+ return tensor;
+}
+
+const Tensor *KernelBuilderHelper::getOptionalInputTensor(const loco::Node *node) const
+{
+ if (dynamic_cast<const luci::CircleOutputExclude *>(node))
+ {
+ return nullptr;
+ }
+ return getInputTensor(node);
+}
+
+Tensor *KernelBuilderHelper::getOutputTensor(const loco::Node *node) const
+{
+ Tensor *tensor = _node_to_tensor.at(node);
+ assert(tensor != nullptr);
+ return tensor;
+}
+
+std::vector<Tensor *>
+KernelBuilderHelper::getOutputTensors(const std::vector<const loco::Node *> &nodes) const
+{
+ std::vector<Tensor *> tensors;
+ tensors.reserve(nodes.size());
+ for (const loco::Node *node : nodes)
+ tensors.push_back(getOutputTensor(node));
+ return tensors;
+}
+
+RuntimeGraph *KernelBuilderHelper::getRuntimeGraph(const loco::Graph *graph) const
+{
+ RuntimeGraph *runtime_graph = _graph_to_runtime_graph.at(graph);
+ assert(runtime_graph != nullptr);
+ return runtime_graph;
+}
+
+} // namespace luci_interpreter
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_LOADER_KERNELBUILDER_HELPER_H
+#define LUCI_INTERPRETER_LOADER_KERNELBUILDER_HELPER_H
+
+#include "core/Kernel.h"
+#include "core/RuntimeGraph.h"
+
+#include <loco/IR/Graph.h>
+#include <loco/IR/Node.h>
+
+#include <vector>
+#include <unordered_map>
+
+namespace luci_interpreter
+{
+
+class KernelBuilderHelper
+{
+public:
+ KernelBuilderHelper(
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph,
+ const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor)
+ : _graph_to_runtime_graph(graph_to_runtime_graph), _node_to_tensor(node_to_tensor)
+ {
+ }
+
+protected:
+ const Tensor *getInputTensor(const loco::Node *node) const;
+ const Tensor *getOptionalInputTensor(const loco::Node *node) const;
+
+ Tensor *getOutputTensor(const loco::Node *node) const;
+ std::vector<Tensor *> getOutputTensors(const std::vector<const loco::Node *> &nodes) const;
+
+ RuntimeGraph *getRuntimeGraph(const loco::Graph *graph) const;
+
+protected:
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &graph_to_runtime_graph() const
+ {
+ return _graph_to_runtime_graph;
+ }
+
+ const std::unordered_map<const loco::Node *, Tensor *> &node_to_tensor() const
+ {
+ return _node_to_tensor;
+ }
+
+private:
+ const std::unordered_map<const loco::Graph *, RuntimeGraph *> &_graph_to_runtime_graph;
+ const std::unordered_map<const loco::Node *, Tensor *> &_node_to_tensor;
+};
+
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_LOADER_KERNELBUILDER_HELPER_H
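// --- Editor's note (illustration only) -------------------------------------------
// A builder class inherits KernelBuilderHelper and resolves loco nodes to runtime
// Tensors through the protected accessors above. The Rsqrt visit below spells out
// the lookups that the abbreviated hunks earlier in this patch omit; it is a
// sketch of the usage pattern, not a quote of the actual file.
std::unique_ptr<Kernel> KernelBuilderLet<KB::OPQR>::visit(const luci::CircleRsqrt *node)
{
  assert(node->arity() == 1);

  const Tensor *input = getInputTensor(node->x()); // producer node -> its runtime Tensor
  Tensor *output = getOutputTensor(node);          // the Tensor this op writes

  return std::make_unique<kernels::Rsqrt>(input, output);
}
// ---------------------------------------------------------------------------------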
addeval(Net_Conv_Add_Mul_001 fuse_batchnorm_with_conv)
addeval(Net_Conv_Add_Mul_002 fuse_batchnorm_with_conv)
addeval(Net_Conv_Min_Max_000 transform_min_max_to_relu6)
+addeval(Net_Conv_Min_Relu_000 transform_min_relu_to_relu6)
addeval(Net_Conv_Relu6_000 fuse_activation_function)
addeval(Net_DwConv_BN_000 fuse_batchnorm_with_dwconv)
addeval(Net_DwConv_BN_001 fuse_batchnorm_with_dwconv)
#addeval(While_003)
#addeval(YUV_TO_RGB_U8_000)
#addeval(ZerosLike_000)
+
+# Simple Network test
+addeval(Part_While_000)
+addeval(Part_While_001)
add_subdirectory(env)
add_subdirectory(log)
add_subdirectory(lang)
+add_subdirectory(logex)
add_subdirectory(testhelper)
add_subdirectory(service)
add_subdirectory(pass)
add_subdirectory(profile)
add_subdirectory(partition)
-add_subdirectory(logex)
add_subdirectory(import)
add_subdirectory(export)
add_subdirectory(tester)
target_include_directories(luci_env PUBLIC include)
target_link_libraries(luci_env PRIVATE nncc_common)
install(TARGETS luci_env DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
if(NOT ENABLE_TEST)
return()
target_link_libraries(luci_export PRIVATE locop)
target_link_libraries(luci_export PRIVATE oops)
install(TARGETS luci_export DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
#if(NOT ENABLE_TEST)
# return()
auto description = _builder.CreateString(description_str);
// Metadata
+ md._metadata.source_table(module->source_table());
auto metadata_vec = createCircleMetadataVector(_builder, md);
auto metadata = _builder.CreateVector(std::vector<Offset<Metadata>>(metadata_vec));
case loco::DataType::BOOL:
return circle::TensorType_BOOL;
+ case loco::DataType::STRING:
+ return circle::TensorType_STRING;
+
default:
INTERNAL_EXN_V("failed to convert unsupported loco::DataType", oops::to_uint32(type));
}
ctx.gd._operators.push_back(op_offset);
}
-class OperationExporter final : public luci::CircleNodeMutableVisitor<void>,
- public loco::CanonicalNodeMutableVisitor<void>
+class ExportHelper
{
public:
- OperationExporter(ExportContext &ctx) : _ctx{ctx}
+ ExportHelper(ExportContext &ctx) : _ctx{ctx}
{
// DO NOTHING
}
+protected:
+ /**
+ * @brief export simple nodes
+ */
+ void export_simple(loco::Node *node, circle::BuiltinOperator bop, circle::BuiltinOptions bot,
+ flatbuffers::Offset<void> options_offset)
+ {
+ export_node(_ctx, node, bop, bot, options_offset);
+ }
+
+ /**
+ * @brief export simple nodes having void options
+ */
+ void export_simple(loco::Node *node, circle::BuiltinOperator bop)
+ {
+ export_node(_ctx, node, bop);
+ }
+
+protected:
+ ExportContext &_ctx;
+};
+
+enum class OE
+{
+ ABC,
+ DEF,
+ GHIJ,
+ KLMN,
+ OPQR,
+ STUV,
+ WXYZ,
+ CIRC, // circle only
+ VIRT, // virtual
+};
+
+class OperationExporter final : public ExportHelper
+{
+public:
+ OperationExporter(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ void export_node(luci::CircleNode *);
+};
+
+template <OE oe> class OpExporterLet;
+
+template <>
+class OpExporterLet<OE::ABC> final : public luci::CircleNodeMutableVisitor<void>,
+ public ExportHelper
+{
+public:
+ OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ // NOTE visit() for luci::CircleNode is added so that unhandled nodes do NOT throw NYI
+ void visit(luci::CircleNode *) final {}
+
public:
void visit(luci::CircleAbs *) final;
void visit(luci::CircleAdd *) final;
void visit(luci::CircleConv2D *) final;
void visit(luci::CircleCos *) final;
void visit(luci::CircleCustom *) final;
+};
+
+template <>
+class OpExporterLet<OE::DEF> final : public luci::CircleNodeMutableVisitor<void>,
+ public ExportHelper
+{
+public:
+ OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ void visit(luci::CircleNode *) final {}
+
+public:
void visit(luci::CircleDepthToSpace *) final;
void visit(luci::CircleDepthwiseConv2D *) final;
void visit(luci::CircleDequantize *) final;
void visit(luci::CircleFloorDiv *) final;
void visit(luci::CircleFloorMod *) final;
void visit(luci::CircleFullyConnected *) final;
+};
+
+template <>
+class OpExporterLet<OE::GHIJ> final : public luci::CircleNodeMutableVisitor<void>,
+ public ExportHelper
+{
+public:
+ OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ void visit(luci::CircleNode *) final {}
+
+public:
void visit(luci::CircleGather *) final;
void visit(luci::CircleGatherNd *) final;
void visit(luci::CircleGreater *) final;
void visit(luci::CircleGreaterEqual *) final;
void visit(luci::CircleIf *) final;
+};
+
+template <>
+class OpExporterLet<OE::KLMN> final : public luci::CircleNodeMutableVisitor<void>,
+ public ExportHelper
+{
+public:
+ OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ void visit(luci::CircleNode *) final {}
+
+public:
void visit(luci::CircleL2Normalize *) final;
void visit(luci::CircleL2Pool2D *) final;
void visit(luci::CircleLeakyRelu *) final;
void visit(luci::CircleNonMaxSuppressionV4 *) final;
void visit(luci::CircleNonMaxSuppressionV5 *) final;
void visit(luci::CircleNotEqual *) final;
+};
+
+template <>
+class OpExporterLet<OE::OPQR> final : public luci::CircleNodeMutableVisitor<void>,
+ public ExportHelper
+{
+public:
+ OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ void visit(luci::CircleNode *) final {}
+
+public:
void visit(luci::CircleOneHot *) final;
void visit(luci::CirclePack *) final;
void visit(luci::CirclePad *) final;
void visit(luci::CirclePadV2 *) final;
void visit(luci::CirclePow *) final;
void visit(luci::CirclePRelu *) final;
+ void visit(luci::CircleQuantize *) final;
void visit(luci::CircleRange *) final;
void visit(luci::CircleRank *) final;
void visit(luci::CircleReduceAny *) final;
void visit(luci::CircleReverseV2 *) final;
void visit(luci::CircleRound *) final;
void visit(luci::CircleRsqrt *) final;
+};
+
+template <>
+class OpExporterLet<OE::STUV> final : public luci::CircleNodeMutableVisitor<void>,
+ public ExportHelper
+{
+public:
+ OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ void visit(luci::CircleNode *) final {}
+
+public:
void visit(luci::CircleScatterNd *) final;
void visit(luci::CircleSegmentSum *) final;
void visit(luci::CircleSelect *) final;
void visit(luci::CircleUnidirectionalSequenceLSTM *) final;
void visit(luci::CircleUnique *) final;
void visit(luci::CircleUnpack *) final;
+};
+
+template <>
+class OpExporterLet<OE::WXYZ> final : public luci::CircleNodeMutableVisitor<void>,
+ public ExportHelper
+{
+public:
+ OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ void visit(luci::CircleNode *) final {}
+
+public:
void visit(luci::CircleWhere *) final;
void visit(luci::CircleWhile *) final;
void visit(luci::CircleZerosLike *) final;
+};
+
+template <>
+class OpExporterLet<OE::CIRC> final : public luci::CircleNodeMutableVisitor<void>,
+ public ExportHelper
+{
+public:
+ OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ void visit(luci::CircleNode *) final {}
+
+public:
// Circle only
void visit(luci::CircleBCQFullyConnected *) final;
void visit(luci::CircleBCQGather *) final;
void visit(luci::CircleInstanceNorm *) final;
+};
+
+template <>
+class OpExporterLet<OE::VIRT> final : public luci::CircleNodeMutableVisitor<void>,
+ public ExportHelper
+{
+public:
+ OpExporterLet(ExportContext &ctx) : ExportHelper(ctx)
+ {
+ // DO NOTHING
+ }
+
+public:
+ void visit(luci::CircleNode *) final {}
+
+public:
// Virtual
void visit(luci::CircleInput *) final {}
void visit(luci::CircleOutput *) final {}
void visit(luci::CircleUniqueOut *) final {}
void visit(luci::CircleUnpackOut *) final {}
void visit(luci::CircleWhileOut *) final {}
-
-private:
- /**
- * @brief export simple nodes
- */
- void export_simple(loco::Node *node, circle::BuiltinOperator bop, circle::BuiltinOptions bot,
- flatbuffers::Offset<void> options_offset);
-
- /**
- * @brief export simple nodes having void options
- */
- void export_simple(loco::Node *node, circle::BuiltinOperator bop);
-
-private:
- ExportContext &_ctx;
};
-void OperationExporter::export_simple(loco::Node *node, circle::BuiltinOperator bop,
- circle::BuiltinOptions bot,
- flatbuffers::Offset<void> options_offset)
+void OperationExporter::export_node(luci::CircleNode *node)
{
- export_node(_ctx, node, bop, bot, options_offset);
-}
+ // TODO revise return type to bool and return if handled
+#define VISIT_OE(GRP) \
+ do \
+ { \
+ OpExporterLet<OE::GRP> oe(_ctx); \
+ node->accept(&oe); \
+ } while (false)
-void OperationExporter::export_simple(loco::Node *node, circle::BuiltinOperator bop)
-{
- export_node(_ctx, node, bop);
+ VISIT_OE(ABC);
+ VISIT_OE(DEF);
+ VISIT_OE(GHIJ);
+ VISIT_OE(KLMN);
+ VISIT_OE(OPQR);
+ VISIT_OE(STUV);
+ VISIT_OE(WXYZ);
+ VISIT_OE(CIRC);
+ VISIT_OE(VIRT);
+
+#undef VISIT_OE
}
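// --- Editor's sketch of the TODO above (not part of the patch) -------------------
// If the per-group visitors returned bool ("handled") instead of void, the dispatch
// could stop at the first group that owns the node rather than running all nine
// groups for every operator, e.g.:
//
//   #define VISIT_OE(GRP)                \
//     do                                 \
//     {                                  \
//       OpExporterLet<OE::GRP> oe(_ctx); \
//       if (node->accept(&oe))           \
//         return;                        \
//     } while (false)
//
// This would require OpExporterLet to derive from CircleNodeMutableVisitor<bool>.
// ---------------------------------------------------------------------------------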
-void OperationExporter::visit(luci::CircleAbs *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleAbs *node)
{
export_simple(node, circle::BuiltinOperator_ABS, circle::BuiltinOptions_AbsOptions,
CreateAbsOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleAdd *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleAdd *node)
{
export_simple(
node, circle::BuiltinOperator_ADD, circle::BuiltinOptions_AddOptions,
CreateAddOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
-void OperationExporter::visit(luci::CircleAddN *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::ABC>::visit(luci::CircleAddN *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleArgMax *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleArgMax *node)
{
export_simple(
node, circle::BuiltinOperator_ARG_MAX, circle::BuiltinOptions_ArgMaxOptions,
CreateArgMaxOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
}
-void OperationExporter::visit(luci::CircleArgMin *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleArgMin *node)
{
export_simple(
node, circle::BuiltinOperator_ARG_MIN, circle::BuiltinOptions_ArgMinOptions,
CreateArgMinOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
}
-void OperationExporter::visit(luci::CircleAveragePool2D *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleAveragePool2D *node)
{
export_pool_2d<luci::CircleAveragePool2D>(_ctx, node, circle::BuiltinOperator_AVERAGE_POOL_2D);
}
-void OperationExporter::visit(luci::CircleBatchMatMul *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleBatchMatMul *node)
{
export_simple(node, circle::BuiltinOperator_BATCH_MATMUL,
circle::BuiltinOptions_BatchMatMulOptions,
CreateBatchMatMulOptions(_ctx.builder, node->adj_x(), node->adj_y()).Union());
}
-void OperationExporter::visit(luci::CircleBidirectionalSequenceLSTM *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleBidirectionalSequenceLSTM *node)
{
auto bidi_lstm_outs = loco::succs(node);
assert((bidi_lstm_outs.size() == 1) || (bidi_lstm_outs.size() == 2));
_ctx.gd._operators.push_back(op_offset);
}
-void OperationExporter::visit(luci::CircleCast *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::ABC>::visit(luci::CircleCast *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleCeil *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleCeil *node)
{
export_simple(node, circle::BuiltinOperator_CEIL);
}
-void OperationExporter::visit(luci::CircleConcatenation *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::ABC>::visit(luci::CircleConcatenation *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleBatchToSpaceND *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleBatchToSpaceND *node)
{
export_simple(node, circle::BuiltinOperator_BATCH_TO_SPACE_ND,
circle::BuiltinOptions_BatchToSpaceNDOptions,
CreateBatchToSpaceNDOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleConv2D *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleConv2D *node)
{
export_simple(node, circle::BuiltinOperator_CONV_2D, circle::BuiltinOptions_Conv2DOptions,
CreateConv2DOptions(_ctx.builder, getOpPadding(node->padding()),
.Union());
}
-void OperationExporter::visit(luci::CircleCos *node)
+void OpExporterLet<OE::ABC>::visit(luci::CircleCos *node)
{
export_simple(node, circle::BuiltinOperator_COS, circle::BuiltinOptions_CosOptions,
CreateCosOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleCustom *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::ABC>::visit(luci::CircleCustom *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleDepthToSpace *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleDepthToSpace *node)
{
export_simple(node, circle::BuiltinOperator_DEPTH_TO_SPACE,
circle::BuiltinOptions_DepthToSpaceOptions,
CreateDepthToSpaceOptions(_ctx.builder, node->block_size()).Union());
}
-void OperationExporter::visit(luci::CircleDepthwiseConv2D *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleDepthwiseConv2D *node)
{
export_simple(
node, circle::BuiltinOperator_DEPTHWISE_CONV_2D, circle::BuiltinOptions_DepthwiseConv2DOptions,
.Union());
}
-void OperationExporter::visit(luci::CircleDequantize *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleDequantize *node)
{
export_simple(node, circle::BuiltinOperator_DEQUANTIZE);
}
-void OperationExporter::visit(luci::CircleDiv *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleDiv *node)
{
export_simple(
node, circle::BuiltinOperator_DIV, circle::BuiltinOptions_DivOptions,
CreateDivOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
-void OperationExporter::visit(luci::CircleElu *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleElu *node)
{
export_simple(node, circle::BuiltinOperator_ELU);
}
-void OperationExporter::visit(luci::CircleEqual *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleEqual *node)
{
export_simple(node, circle::BuiltinOperator_EQUAL, circle::BuiltinOptions_EqualOptions,
CreateEqualOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleExp *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleExp *node)
{
export_simple(node, circle::BuiltinOperator_EXP, circle::BuiltinOptions_ExpOptions,
CreateExpOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleExpandDims *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleExpandDims *node)
{
export_simple(node, circle::BuiltinOperator_EXPAND_DIMS, circle::BuiltinOptions_ExpandDimsOptions,
CreateExpandDimsOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleFakeQuant *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleFakeQuant *node)
{
export_simple(node, circle::BuiltinOperator_FAKE_QUANT, circle::BuiltinOptions_FakeQuantOptions,
CreateFakeQuantOptions(_ctx.builder, node->min(), node->max(), node->num_bits(),
.Union());
}
-void OperationExporter::visit(luci::CircleFill *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleFill *node)
{
export_simple(node, circle::BuiltinOperator_FILL, circle::BuiltinOptions_FillOptions,
CreateFillOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleFloor *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleFloor *node)
{
export_simple(node, circle::BuiltinOperator_FLOOR);
}
-void OperationExporter::visit(luci::CircleFloorDiv *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleFloorDiv *node)
{
export_simple(node, circle::BuiltinOperator_FLOOR_DIV, circle::BuiltinOptions_FloorDivOptions,
CreateFloorDivOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleFloorMod *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleFloorMod *node)
{
export_simple(node, circle::BuiltinOperator_FLOOR_MOD, circle::BuiltinOptions_FloorModOptions,
CreateFloorModOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleFullyConnected *node)
+void OpExporterLet<OE::DEF>::visit(luci::CircleFullyConnected *node)
{
export_simple(
node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions,
.Union());
}
-void OperationExporter::visit(luci::CircleGather *node)
+void OpExporterLet<OE::GHIJ>::visit(luci::CircleGather *node)
{
export_simple(node, circle::BuiltinOperator_GATHER, circle::BuiltinOptions_GatherOptions,
CreateGatherOptions(_ctx.builder, node->axis()).Union());
}
-void OperationExporter::visit(luci::CircleGatherNd *node)
+void OpExporterLet<OE::GHIJ>::visit(luci::CircleGatherNd *node)
{
export_simple(node, circle::BuiltinOperator_GATHER_ND, circle::BuiltinOptions_GatherNdOptions,
CreateGatherNdOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleGreater *node)
+void OpExporterLet<OE::GHIJ>::visit(luci::CircleGreater *node)
{
export_simple(node, circle::BuiltinOperator_GREATER, circle::BuiltinOptions_GreaterOptions,
CreateGreaterOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleGreaterEqual *node)
+void OpExporterLet<OE::GHIJ>::visit(luci::CircleGreaterEqual *node)
{
export_simple(node, circle::BuiltinOperator_GREATER_EQUAL,
circle::BuiltinOptions_GreaterEqualOptions,
CreateGreaterEqualOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleIf *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::GHIJ>::visit(luci::CircleIf *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleL2Normalize *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleL2Normalize *node)
{
export_simple(
node, circle::BuiltinOperator_L2_NORMALIZATION, circle::BuiltinOptions_L2NormOptions,
CreateL2NormOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
-void OperationExporter::visit(luci::CircleL2Pool2D *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleL2Pool2D *node)
{
export_pool_2d<luci::CircleL2Pool2D>(_ctx, node, circle::BuiltinOperator_L2_POOL_2D);
}
-void OperationExporter::visit(luci::CircleLeakyRelu *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLeakyRelu *node)
{
export_simple(node, circle::BuiltinOperator_LEAKY_RELU, circle::BuiltinOptions_LeakyReluOptions,
CreateLeakyReluOptions(_ctx.builder, node->alpha()).Union());
}
-void OperationExporter::visit(luci::CircleLess *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLess *node)
{
export_simple(node, circle::BuiltinOperator_LESS, circle::BuiltinOptions_LessOptions,
CreateLessOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleLessEqual *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLessEqual *node)
{
export_simple(node, circle::BuiltinOperator_LESS_EQUAL, circle::BuiltinOptions_LessEqualOptions,
CreateLessEqualOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleLocalResponseNormalization *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLocalResponseNormalization *node)
{
export_simple(node, circle::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
circle::BuiltinOptions_LocalResponseNormalizationOptions,
.Union());
}
-void OperationExporter::visit(luci::CircleLog *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLog *node)
{
export_simple(node, circle::BuiltinOperator_LOG);
}
-void OperationExporter::visit(luci::CircleLogicalAnd *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalAnd *node)
{
export_simple(node, circle::BuiltinOperator_LOGICAL_AND, circle::BuiltinOptions_LogicalAndOptions,
CreateLogicalAndOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleLogicalNot *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalNot *node)
{
export_simple(node, circle::BuiltinOperator_LOGICAL_NOT, circle::BuiltinOptions_LogicalNotOptions,
CreateLogicalNotOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleLogicalOr *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLogicalOr *node)
{
export_simple(node, circle::BuiltinOperator_LOGICAL_OR, circle::BuiltinOptions_LogicalOrOptions,
CreateLogicalOrOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleLogistic *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLogistic *node)
{
export_simple(node, circle::BuiltinOperator_LOGISTIC);
}
-void OperationExporter::visit(luci::CircleLogSoftmax *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleLogSoftmax *node)
{
export_simple(node, circle::BuiltinOperator_LOG_SOFTMAX, circle::BuiltinOptions_LogSoftmaxOptions,
CreateLogSoftmaxOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleMatrixDiag *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleMatrixDiag *node)
{
export_simple(node, circle::BuiltinOperator_MATRIX_DIAG, circle::BuiltinOptions_MatrixDiagOptions,
CreateMatrixDiagOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleMatrixSetDiag *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleMatrixSetDiag *node)
{
export_simple(node, circle::BuiltinOperator_MATRIX_SET_DIAG,
circle::BuiltinOptions_MatrixSetDiagOptions,
CreateMatrixSetDiagOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleMaximum *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleMaximum *node)
{
export_simple(node, circle::BuiltinOperator_MAXIMUM, circle::BuiltinOptions_MaximumMinimumOptions,
CreateMaximumMinimumOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleMaxPool2D *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleMaxPool2D *node)
{
export_pool_2d<luci::CircleMaxPool2D>(_ctx, node, circle::BuiltinOperator_MAX_POOL_2D);
}
-void OperationExporter::visit(luci::CircleMean *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleMean *node)
{
export_simple(node, circle::BuiltinOperator_MEAN, circle::BuiltinOptions_ReducerOptions,
CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
}
-void OperationExporter::visit(luci::CircleMinimum *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleMinimum *node)
{
export_simple(node, circle::BuiltinOperator_MINIMUM, circle::BuiltinOptions_MaximumMinimumOptions,
CreateMaximumMinimumOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleMirrorPad *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleMirrorPad *node)
{
export_simple(
node, circle::BuiltinOperator_MIRROR_PAD, circle::BuiltinOptions_MirrorPadOptions,
CreateMirrorPadOptions(_ctx.builder, to_circle_mirrorpadmode(node->mode())).Union());
}
-void OperationExporter::visit(luci::CircleMul *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleMul *node)
{
export_simple(
node, circle::BuiltinOperator_MUL, circle::BuiltinOptions_MulOptions,
CreateMulOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
-void OperationExporter::visit(luci::CircleNeg *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleNeg *node)
{
export_simple(node, circle::BuiltinOperator_NEG, circle::BuiltinOptions_NegOptions,
CreateNegOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleNonMaxSuppressionV4 *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::KLMN>::visit(luci::CircleNonMaxSuppressionV4 *node)
+{
+ export_node(_ctx, node);
+}
-void OperationExporter::visit(luci::CircleNonMaxSuppressionV5 *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::KLMN>::visit(luci::CircleNonMaxSuppressionV5 *node)
+{
+ export_node(_ctx, node);
+}
-void OperationExporter::visit(luci::CircleNotEqual *node)
+void OpExporterLet<OE::KLMN>::visit(luci::CircleNotEqual *node)
{
export_simple(node, circle::BuiltinOperator_NOT_EQUAL, circle::BuiltinOptions_NotEqualOptions,
CreateNotEqualOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleOneHot *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleOneHot *node)
{
export_simple(node, circle::BuiltinOperator_ONE_HOT, circle::BuiltinOptions_OneHotOptions,
CreateOneHotOptions(_ctx.builder, node->axis()).Union());
}
-void OperationExporter::visit(luci::CirclePack *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CirclePack *node)
{
export_simple(node, circle::BuiltinOperator_PACK, circle::BuiltinOptions_PackOptions,
CreatePackOptions(_ctx.builder, node->values_count(), node->axis()).Union());
}
-void OperationExporter::visit(luci::CirclePad *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CirclePad *node)
{
export_simple(node, circle::BuiltinOperator_PAD, circle::BuiltinOptions_PadOptions,
CreatePadOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CirclePadV2 *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CirclePadV2 *node)
{
export_simple(node, circle::BuiltinOperator_PADV2, circle::BuiltinOptions_PadV2Options,
CreatePadV2Options(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CirclePow *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CirclePow *node)
{
export_simple(node, circle::BuiltinOperator_POW, circle::BuiltinOptions_PowOptions,
CreatePowOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CirclePRelu *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CirclePRelu *node)
{
export_simple(node, circle::BuiltinOperator_PRELU);
}
-void OperationExporter::visit(luci::CircleRange *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleQuantize *node)
+{
+ export_simple(node, circle::BuiltinOperator_QUANTIZE);
+}
+
+void OpExporterLet<OE::OPQR>::visit(luci::CircleRange *node)
{
export_simple(node, circle::BuiltinOperator_RANGE, circle::BuiltinOptions_RangeOptions,
CreateRangeOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleRank *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleRank *node)
{
export_simple(node, circle::BuiltinOperator_RANK, circle::BuiltinOptions_RankOptions,
CreateRankOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleReduceAny *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceAny *node)
{
export_simple(node, circle::BuiltinOperator_REDUCE_ANY, circle::BuiltinOptions_ReducerOptions,
CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
}
-void OperationExporter::visit(luci::CircleReduceMax *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceMax *node)
{
export_simple(node, circle::BuiltinOperator_REDUCE_MAX, circle::BuiltinOptions_ReducerOptions,
CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
}
-void OperationExporter::visit(luci::CircleReduceMin *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceMin *node)
{
export_simple(node, circle::BuiltinOperator_REDUCE_MIN, circle::BuiltinOptions_ReducerOptions,
CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
}
-void OperationExporter::visit(luci::CircleReduceProd *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleReduceProd *node)
{
export_simple(node, circle::BuiltinOperator_REDUCE_PROD, circle::BuiltinOptions_ReducerOptions,
CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
}
-void OperationExporter::visit(luci::CircleRelu *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleRelu *node)
{
export_simple(node, circle::BuiltinOperator_RELU);
}
-void OperationExporter::visit(luci::CircleRelu6 *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleRelu6 *node)
{
export_simple(node, circle::BuiltinOperator_RELU6);
}
-void OperationExporter::visit(luci::CircleReluN1To1 *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleReluN1To1 *node)
{
export_simple(node, circle::BuiltinOperator_RELU_N1_TO_1);
}
-void OperationExporter::visit(luci::CircleReshape *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleReshape *node)
{
auto new_shape = _ctx.builder.CreateVector<int32_t>(
node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); });
CreateReshapeOptions(_ctx.builder, new_shape).Union());
}
-void OperationExporter::visit(luci::CircleResizeBilinear *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleResizeBilinear *node)
{
export_simple(
node, circle::BuiltinOperator_RESIZE_BILINEAR, circle::BuiltinOptions_ResizeBilinearOptions,
.Union());
}
-void OperationExporter::visit(luci::CircleResizeNearestNeighbor *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleResizeNearestNeighbor *node)
{
export_simple(node, circle::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
circle::BuiltinOptions_ResizeNearestNeighborOptions,
CreateResizeNearestNeighborOptions(_ctx.builder, node->align_corners()).Union());
}
-void OperationExporter::visit(luci::CircleReverseSequence *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleReverseSequence *node)
{
export_simple(
node, circle::BuiltinOperator_REVERSE_SEQUENCE, circle::BuiltinOptions_ReverseSequenceOptions,
CreateReverseSequenceOptions(_ctx.builder, node->seq_axis(), node->batch_axis()).Union());
}
-void OperationExporter::visit(luci::CircleReverseV2 *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::OPQR>::visit(luci::CircleReverseV2 *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleRound *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleRound *node)
{
export_simple(node, circle::BuiltinOperator_ROUND);
}
-void OperationExporter::visit(luci::CircleRsqrt *node)
+void OpExporterLet<OE::OPQR>::visit(luci::CircleRsqrt *node)
{
export_simple(node, circle::BuiltinOperator_RSQRT);
}
-void OperationExporter::visit(luci::CircleScatterNd *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleScatterNd *node)
{
export_simple(node, circle::BuiltinOperator_SCATTER_ND, circle::BuiltinOptions_ScatterNdOptions,
CreateScatterNdOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleSegmentSum *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSegmentSum *node)
{
export_simple(node, circle::BuiltinOperator_SEGMENT_SUM, circle::BuiltinOptions_SegmentSumOptions,
CreateSegmentSumOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleSelect *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSelect *node)
{
export_simple(node, circle::BuiltinOperator_SELECT, circle::BuiltinOptions_SelectOptions,
CreateSelectOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleSelectV2 *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSelectV2 *node)
{
export_simple(node, circle::BuiltinOperator_SELECT_V2, circle::BuiltinOptions_SelectV2Options,
CreateSelectV2Options(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleShape *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleShape *node)
{
export_simple(node, circle::BuiltinOperator_SHAPE, circle::BuiltinOptions_ShapeOptions,
CreateShapeOptions(_ctx.builder, to_circle_tensortype(node->out_type())).Union());
}
-void OperationExporter::visit(luci::CircleSin *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSin *node)
{
export_simple(node, circle::BuiltinOperator_SIN);
}
-void OperationExporter::visit(luci::CircleSlice *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSlice *node)
{
export_simple(node, circle::BuiltinOperator_SLICE, circle::BuiltinOptions_SliceOptions,
CreateSliceOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleSoftmax *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSoftmax *node)
{
export_simple(node, circle::BuiltinOperator_SOFTMAX, circle::BuiltinOptions_SoftmaxOptions,
CreateSoftmaxOptions(_ctx.builder, node->beta()).Union());
}
-void OperationExporter::visit(luci::CircleSpaceToBatchND *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSpaceToBatchND *node)
{
export_simple(node, circle::BuiltinOperator_SPACE_TO_BATCH_ND,
circle::BuiltinOptions_SpaceToBatchNDOptions,
CreateSpaceToBatchNDOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleSpaceToDepth *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSpaceToDepth *node)
{
export_simple(node, circle::BuiltinOperator_SPACE_TO_DEPTH,
circle::BuiltinOptions_SpaceToDepthOptions,
CreateSpaceToDepthOptions(_ctx.builder, node->block_size()).Union());
}
-void OperationExporter::visit(luci::CircleSparseToDense *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSparseToDense *node)
{
export_simple(node, circle::BuiltinOperator_SPARSE_TO_DENSE,
circle::BuiltinOptions_SparseToDenseOptions,
CreateSparseToDenseOptions(_ctx.builder, node->validate_indices()).Union());
}
-void OperationExporter::visit(luci::CircleSplit *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::STUV>::visit(luci::CircleSplit *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleSplitV *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::STUV>::visit(luci::CircleSplitV *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleSqrt *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSqrt *node)
{
export_simple(node, circle::BuiltinOperator_SQRT);
}
-void OperationExporter::visit(luci::CircleSquare *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSquare *node)
{
export_simple(node, circle::BuiltinOperator_SQUARE, circle::BuiltinOptions_SquareOptions,
CreateSquareOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleSquaredDifference *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSquaredDifference *node)
{
export_simple(node, circle::BuiltinOperator_SQUARED_DIFFERENCE,
circle::BuiltinOptions_SquaredDifferenceOptions,
CreateSquaredDifferenceOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleSqueeze *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSqueeze *node)
{
auto squeeze_dims = _ctx.builder.CreateVector<int32_t>(node->squeeze_dims());
export_simple(node, circle::BuiltinOperator_SQUEEZE, circle::BuiltinOptions_SqueezeOptions,
CreateSqueezeOptions(_ctx.builder, squeeze_dims).Union());
}
-void OperationExporter::visit(luci::CircleStridedSlice *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleStridedSlice *node)
{
export_simple(node, circle::BuiltinOperator_STRIDED_SLICE,
circle::BuiltinOptions_StridedSliceOptions,
.Union());
}
-void OperationExporter::visit(luci::CircleSub *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSub *node)
{
export_simple(
node, circle::BuiltinOperator_SUB, circle::BuiltinOptions_SubOptions,
CreateSubOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
-void OperationExporter::visit(luci::CircleSum *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleSum *node)
{
export_simple(node, circle::BuiltinOperator_SUM, circle::BuiltinOptions_ReducerOptions,
CreateReducerOptions(_ctx.builder, node->keep_dims()).Union());
}
-void OperationExporter::visit(luci::CircleTanh *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleTanh *node)
{
export_simple(node, circle::BuiltinOperator_TANH);
}
-void OperationExporter::visit(luci::CircleTile *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleTile *node)
{
export_simple(node, circle::BuiltinOperator_TILE, circle::BuiltinOptions_TileOptions,
CreateTileOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleTopKV2 *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::STUV>::visit(luci::CircleTopKV2 *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleTranspose *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleTranspose *node)
{
export_simple(node, circle::BuiltinOperator_TRANSPOSE, circle::BuiltinOptions_TransposeOptions,
CreateTransposeOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleTransposeConv *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleTransposeConv *node)
{
export_simple(node, circle::BuiltinOperator_TRANSPOSE_CONV,
circle::BuiltinOptions_TransposeConvOptions,
.Union());
}
-void OperationExporter::visit(luci::CircleUnidirectionalSequenceLSTM *node)
+void OpExporterLet<OE::STUV>::visit(luci::CircleUnidirectionalSequenceLSTM *node)
{
export_simple(node, circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
circle::BuiltinOptions_UnidirectionalSequenceLSTMOptions,
.Union());
}
-void OperationExporter::visit(luci::CircleUnique *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::STUV>::visit(luci::CircleUnique *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleUnpack *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::STUV>::visit(luci::CircleUnpack *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleWhere *node)
+void OpExporterLet<OE::WXYZ>::visit(luci::CircleWhere *node)
{
export_simple(node, circle::BuiltinOperator_WHERE, circle::BuiltinOptions_WhereOptions,
CreateWhereOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleWhile *node) { export_node(_ctx, node); }
+void OpExporterLet<OE::WXYZ>::visit(luci::CircleWhile *node) { export_node(_ctx, node); }
-void OperationExporter::visit(luci::CircleZerosLike *node)
+void OpExporterLet<OE::WXYZ>::visit(luci::CircleZerosLike *node)
{
export_simple(node, circle::BuiltinOperator_ZEROS_LIKE, circle::BuiltinOptions_ZerosLikeOptions,
CreateZerosLikeOptions(_ctx.builder).Union());
}
-void OperationExporter::visit(luci::CircleBCQFullyConnected *node)
+void OpExporterLet<OE::CIRC>::visit(luci::CircleBCQFullyConnected *node)
{
export_simple(node, circle::BuiltinOperator_BCQ_FULLY_CONNECTED,
circle::BuiltinOptions_BCQFullyConnectedOptions,
.Union());
}
-void OperationExporter::visit(luci::CircleBCQGather *node)
+void OpExporterLet<OE::CIRC>::visit(luci::CircleBCQGather *node)
{
export_simple(
node, circle::BuiltinOperator_BCQ_GATHER, circle::BuiltinOptions_BCQGatherOptions,
CreateBCQGatherOptions(_ctx.builder, node->input_hidden_size(), node->axis()).Union());
}
-void OperationExporter::visit(luci::CircleInstanceNorm *node)
+void OpExporterLet<OE::CIRC>::visit(luci::CircleInstanceNorm *node)
{
export_simple(node, circle::BuiltinOperator_INSTANCE_NORM,
circle::BuiltinOptions_InstanceNormOptions,
const auto ops_size = gd._operators.size();
- circle_node->accept(&exporter);
+ exporter.export_node(circle_node);
if (has_origin(circle_node) && ops_size != gd._operators.size())
{
const auto node_id = gd._operators.size() - 1;
for (auto source : get_origin(circle_node)->sources())
{
- md._metadata.add_source_table(source->id(), source->name());
md._metadata.add_op_table(node_id, source->id());
}
}
#include <loco/IR/DataTypeTraits.h>
#include <oops/InternalExn.h>
+#include <string.h>
+
using namespace circle;
using namespace flatbuffers;
}
template <>
+flatbuffers::Offset<circle::Buffer>
+encodeOpBufferByDType<loco::DataType::STRING>(FlatBufferBuilder &builder, luci::CircleConst *c)
+{
+ const uint32_t count = c->size<loco::DataType::STRING>();
+ uint32_t raw_size = sizeof(int32_t) * (count + 2);
+ for (uint32_t i = 0; i < count; ++i)
+ {
+ auto &value = c->at<loco::DataType::STRING>(i);
+ raw_size += value.length();
+ }
+
+ // serialize string data
+ // int32_t count
+ // int32_t offsets[count + 1]
+ // string values[count]
+ std::vector<uint8_t> raw_data;
+  raw_data.resize(raw_size); // allocate the full buffer; bytes are filled in place below
+
+ auto *i32d = reinterpret_cast<int32_t *>(raw_data.data());
+ int32_t start = sizeof(int32_t) * (count + 2);
+ int32_t offset = start;
+ std::vector<int32_t> offsets;
+
+ *i32d++ = count;
+ *i32d++ = start;
+ offsets.push_back(start);
+ for (uint32_t i = 0; i < count; ++i)
+ {
+ auto &value = c->at<loco::DataType::STRING>(i);
+ offset += value.length();
+ *i32d++ = offset;
+ offsets.push_back(offset);
+ }
+
+ auto *data = reinterpret_cast<uint8_t *>(i32d);
+ for (uint32_t i = 0; i < count; ++i)
+ {
+ int32_t length = offsets[i + 1] - offsets[i];
+ auto &value = c->at<loco::DataType::STRING>(i);
+ memcpy(data, value.c_str(), length);
+ data += length;
+ }
+
+ auto array_offset = builder.CreateVector(reinterpret_cast<uint8_t *>(raw_data.data()), raw_size);
+ return CreateBuffer(builder, array_offset);
+}
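// --- Editor's worked example of the layout above ----------------------------------
// For a STRING constant holding {"ab", "cde"}: count = 2, so the header takes
// sizeof(int32_t) * (count + 2) = 16 bytes and raw_size = 16 + 2 + 3 = 21.
//   int32_t count      = 2
//   int32_t offsets[3] = {16, 18, 21}   // offsets[0] is where the first string starts
//   bytes   values     = "ab" "cde"     // i-th length = offsets[i + 1] - offsets[i]
// This is the same count/offsets/values layout TFLite uses for string tensors.
// -----------------------------------------------------------------------------------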
+
+template <>
flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder, luci::CircleConst *c)
{
switch (c->dtype())
return encodeOpBufferByDType<loco::DataType::U8>(builder, c);
case loco::DataType::BOOL:
return encodeOpBufferByDType<loco::DataType::BOOL>(builder, c);
+ case loco::DataType::STRING:
+ return encodeOpBufferByDType<loco::DataType::STRING>(builder, c);
default:
break;
}
}
else
{
- // When there is no CircleConst, the operation do not use buffer.
- // So return buffer id as 0 which means empty buffer in circle schema.
- return 0;
+ // When there is no CircleConst, there is nothing to cache,
+ // so allocate a new (empty) buffer and return its id.
+ auto buffer = encodeOpBuffer(builder);
+
+ auto buffer_id = static_cast<uint32_t>(md._buffers.size());
+ md._buffers.push_back(buffer);
+
+ return buffer_id;
}
}
class CircleExportMetadata
{
public:
- void add_source_table(uint32_t source_id, std::string origin_name)
- {
- // Model with multiple subgraph may have different origin_name
- // even if source_id is same. However, as we do not consider about
- // multiple subgraph in profiling for now, just do not care those cases
- // and support them correctly in the future.
- _source_table.emplace(source_id, origin_name);
- }
+ void source_table(const std::map<uint32_t, std::string> &table) { _source_table = table; }
void add_op_table(uint32_t node_id, uint32_t source_id)
{
target_link_libraries(luci_import PRIVATE locop)
target_link_libraries(luci_import PRIVATE oops)
install(TARGETS luci_import DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
if(NOT ENABLE_TEST)
return()
#include "Nodes/CirclePadV2.h"
#include "Nodes/CirclePow.h"
#include "Nodes/CirclePRelu.h"
+#include "Nodes/CircleQuantize.h"
#include "Nodes/CircleRange.h"
#include "Nodes/CircleRank.h"
#include "Nodes/CircleReduceAny.h"
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_OP_CIRCLE_QUANTIZE_H__
+#define __LUCI_IMPORT_OP_CIRCLE_QUANTIZE_H__
+
+#include "luci/Import/GraphBuilder.h"
+
+namespace luci
+{
+
+class CircleQuantizeGraphBuilder : public GraphBuilder
+{
+public:
+ bool validate(const ValidateArgs &args) const final;
+
+private:
+ CircleNode *build_node(const circle::OperatorT &op, const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_OP_CIRCLE_QUANTIZE_H__
*/
const OriginTable origin_table(void);
+ const std::map<uint32_t, std::string> &source_table(void) const { return _source_table; }
+
private:
// Decoded metadata is stored
std::map<uint32_t, std::string> _source_table;
case circle::TensorType_INT64:
return loco::DataType::S64;
case circle::TensorType_STRING:
- break;
+ return loco::DataType::STRING;
case circle::TensorType_BOOL:
return loco::DataType::BOOL;
case circle::TensorType_INT16:
auto *node = build_node(bna);
uint32_t output_count = outputs.size();
- assert(output_count > 0);
+ // NOTE CustomOp inherits GraphBuilderMultiOutput and can have 0 output
+ if (output_count > 0)
{
// Let's use attributes from output 0 for this node
const circle::TensorT &output_tensor = *tensors[outputs[0]];
auto *nodeout = build_out(boa);
copy_tensor_attributes(output_tensor, nodeout);
+ // NOTE CxxxOut nodes may share the same name
// mark shape_status
if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
nodeout->shape_status(ShapeStatus::NOSHAPE);
CIRCLE_NODE(PADV2, CirclePadV2GraphBuilder); // 60
CIRCLE_NODE(POW, CirclePowGraphBuilder); // 78
CIRCLE_NODE(PRELU, CirclePReluGraphBuilder); // 54,
+ CIRCLE_NODE(QUANTIZE, CircleQuantizeGraphBuilder); // 114,
CIRCLE_NODE(RANGE, CircleRangeGraphBuilder); // 96
CIRCLE_NODE(RANK, CircleRankGraphBuilder); // 110
CIRCLE_NODE(REDUCE_ANY, CircleReduceAnyGraphBuilder); // 91
// BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46,
// BuiltinOperator_DELEGATE = 51,
// BuiltinOperator_ARG_MAX = 56,
- // BuiltinOperator_QUANTIZE = 114,
// BuiltinOperator_HARD_SWISH = 117,
// BuiltinOperator_DENSIFY = 124,
}
post_import_graph(module.get(), reader);
+ // Initialize 'source_table'
+ auto circle_metadata = std::make_unique<luci::CircleImportMetadata>(reader);
+ if (circle_metadata->source_table().size() > 0)
+ {
+ // If the circle model has 'source_table' metadata, copy the table.
+ module->source_table(circle_metadata->source_table());
+ }
+ else
+ {
+ // If there is no 'source_table' metadata in the circle model,
+ // create a new table from the circle nodes.
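+ // The resulting table maps node id to node name,
+ // e.g. (illustrative) { 0: "ifm", 1: "conv2d", 2: "relu" }.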
+ std::map<uint32_t, std::string> table;
+
+ // NOTE Only first subgraph is considered
+ for (auto node : loco::all_nodes(module->graph(0)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+ // Virtual nodes may not have id
+ if (!has_node_id(circle_node))
+ continue;
+
+ assert(table.find(get_node_id(circle_node)) == table.end());
+ table.insert({get_node_id(circle_node), circle_node->name()});
+ }
+
+ module->source_table(table);
+ }
+
return module;
}
#include <oops/UserExn.h>
#include <cassert>
+#include <ostream>
+#include <string>
+#include <vector>
namespace
{
return os;
}
-} // namespace
-
-namespace luci
-{
+using namespace luci;
template <loco::DataType DT>
-static void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements,
- CircleConst *const_node)
+void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements, CircleConst *const_node)
{
using T = typename loco::DataTypeImpl<DT>::Type;
}
}
+template <>
+void copy_data<loco::DataType::STRING>(const std::vector<uint8_t> &raw_data, uint32_t num_elements,
+ CircleConst *const_node)
+{
+ assert(const_node->sparsityparam() == nullptr);
+
+ const auto *data = reinterpret_cast<const char *>(raw_data.data());
+ const auto *i32d = reinterpret_cast<const int32_t *>(raw_data.data());
+
+ // de-serialize string data
+ // int32_t count
+ // int32_t offsets[count + 1]
+ // string values[count]
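+ //
+ // This mirrors the layout written by the exporter; e.g. (illustrative)
+ // [ 2 | 16 | 18 | 21 | "ab" "cde" ] decodes to the two strings "ab" and "cde".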
+ assert(static_cast<uint32_t>(*i32d) == num_elements);
+ i32d++; // skip count
+
+ std::vector<int32_t> offsets;
+ offsets.push_back(*i32d++);
+ for (uint32_t i = 0; i < num_elements; ++i)
+ {
+ offsets.push_back(*i32d++);
+ }
+ assert(offsets.size() == num_elements + 1);
+
+ const_node->size<loco::DataType::STRING>(num_elements);
+ for (uint32_t i = 0; i < num_elements; ++i)
+ {
+ int32_t start = offsets[i];
+ int32_t next = offsets[i + 1];
+
+ std::string value(data + start, next - start);
+ const_node->at<loco::DataType::STRING>(i) = value;
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index)
{
LOGGER(l);
copy_data<loco::DataType::BOOL>(buffer, num_elements, const_node);
break;
+ case loco::DataType::STRING:
+ copy_data<loco::DataType::STRING>(buffer, num_elements, const_node);
+ break;
+
default:
throw oops::UserExn("Unsupported tensor type",
circle::EnumNameTensorType(const_tensor.type));
switch (tensor->type)
{
+ case circle::TensorType_FLOAT64:
+ break;
case circle::TensorType_FLOAT32:
break;
+ case circle::TensorType_INT16:
+ break;
+ case circle::TensorType_UINT8:
+ break;
default:
return false;
}
case circle::TensorType_FLOAT32:
case circle::TensorType_FLOAT64:
break;
+ // Additional support for quantized tensors
+ case circle::TensorType_UINT8:
+ case circle::TensorType_INT16:
+ break;
// TODO support TensorType_COMPLEX64, complex128, bfloat16
default:
return false;
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleQuantize.h"
+
+#include <luci/IR/Nodes/CircleQuantize.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleQuantizeGraphBuilder::validate(const ValidateArgs &args) const
+{
+ return GraphBuilder::validate(args, 1);
+}
+
+CircleNode *CircleQuantizeGraphBuilder::build_node(const circle::OperatorT &,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleQuantize>();
+ node->input(inputs.at(0));
+
+ // No options for Quantize
+
+ return node;
+}
+
+} // namespace luci
case circle::TensorType_COMPLEX64:
break;
// TODO support bfloat16, complex128
+ // Additional support for quantized tensors
+ case circle::TensorType_UINT8:
+ case circle::TensorType_INT16:
+ break;
default:
return false;
}
target_link_libraries(luci_lang PRIVATE nncc_common)
install(TARGETS luci_lang DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
if(NOT ENABLE_TEST)
return()
*
* TODO Deprecated this class, and use loco::FixedArity instead
*/
-template <unsigned N, typename Base> class FixedArityNode : public Base
+template <uint32_t N, typename Base> class FixedArityNode : public Base
{
public:
FixedArityNode()
virtual ~FixedArityNode() = default;
public:
- unsigned arity(void) const final { return N; }
+ uint32_t arity(void) const final { return N; }
loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
protected:
// This API allows inherited classes to access "_args" field.
- loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+ loco::Use *at(uint32_t n) const { return _args.at(n).get(); }
private:
std::vector<std::unique_ptr<loco::Use>> _args{};
#include "Nodes/CirclePad.h"
#include "Nodes/CirclePadV2.h"
#include "Nodes/CirclePow.h"
+#include "Nodes/CircleQuantize.h"
#include "Nodes/CirclePRelu.h"
#include "Nodes/CircleRange.h"
#include "Nodes/CircleRank.h"
// .Input("value: T") <-- Input name is 'value'
//
-CIRCLE_NODE(ABS, luci::CircleAbs)
-CIRCLE_NODE(ADD, luci::CircleAdd)
-CIRCLE_NODE(ADD_N, luci::CircleAddN)
-CIRCLE_NODE(ARG_MAX, luci::CircleArgMax)
-CIRCLE_NODE(ARG_MIN, luci::CircleArgMin)
-CIRCLE_NODE(AVERAGE_POOL_2D, luci::CircleAveragePool2D)
-CIRCLE_NODE(BATCH_TO_SPACE_ND, luci::CircleBatchToSpaceND)
-CIRCLE_NODE(BATCH_MATMUL, luci::CircleBatchMatMul)
-CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, luci::CircleBidirectionalSequenceLSTM)
-CIRCLE_NODE(CAST, luci::CircleCast)
-CIRCLE_NODE(CEIL, luci::CircleCeil)
-CIRCLE_NODE(CONCATENATION, luci::CircleConcatenation)
-CIRCLE_NODE(CONV_2D, luci::CircleConv2D)
-CIRCLE_NODE(COS, luci::CircleCos)
-CIRCLE_NODE(CUSTOM, luci::CircleCustom)
-CIRCLE_NODE(DEPTH_TO_SPACE, luci::CircleDepthToSpace)
-CIRCLE_NODE(DEPTHWISE_CONV_2D, luci::CircleDepthwiseConv2D)
-CIRCLE_NODE(DEQUANTIZE, luci::CircleDequantize)
-CIRCLE_NODE(DIV, luci::CircleDiv)
-CIRCLE_NODE(ELU, luci::CircleElu)
-CIRCLE_NODE(EQUAL, luci::CircleEqual)
-CIRCLE_NODE(EXP, luci::CircleExp)
-CIRCLE_NODE(EXPAND_DIMS, luci::CircleExpandDims)
-CIRCLE_NODE(FAKE_QUANT, luci::CircleFakeQuant)
-CIRCLE_NODE(FILL, luci::CircleFill)
-CIRCLE_NODE(FLOOR, luci::CircleFloor)
-CIRCLE_NODE(FLOOR_DIV, luci::CircleFloorDiv)
-CIRCLE_NODE(FLOOR_MOD, luci::CircleFloorMod)
-CIRCLE_NODE(FULLY_CONNECTED, luci::CircleFullyConnected)
-CIRCLE_NODE(GATHER, luci::CircleGather)
-CIRCLE_NODE(GATHER_ND, luci::CircleGatherNd)
-CIRCLE_NODE(GREATER, luci::CircleGreater)
-CIRCLE_NODE(GREATER_EQUAL, luci::CircleGreaterEqual)
-CIRCLE_NODE(IF, luci::CircleIf)
-CIRCLE_NODE(L2_NORMALIZATION, luci::CircleL2Normalize)
-CIRCLE_NODE(L2_POOL_2D, luci::CircleL2Pool2D)
-CIRCLE_NODE(LEAKY_RELU, luci::CircleLeakyRelu)
-CIRCLE_NODE(LESS, luci::CircleLess)
-CIRCLE_NODE(LESS_EQUAL, luci::CircleLessEqual)
-CIRCLE_NODE(LOCAL_RESPONSE_NORMALIZATION, luci::CircleLocalResponseNormalization)
-CIRCLE_NODE(LOG, luci::CircleLog)
-CIRCLE_NODE(LOGICAL_AND, luci::CircleLogicalAnd)
-CIRCLE_NODE(LOGICAL_NOT, luci::CircleLogicalNot)
-CIRCLE_NODE(LOGICAL_OR, luci::CircleLogicalOr)
-CIRCLE_NODE(LOGISTIC, luci::CircleLogistic)
-CIRCLE_NODE(LOG_SOFTMAX, luci::CircleLogSoftmax)
-CIRCLE_NODE(MATRIX_DIAG, luci::CircleMatrixDiag)
-CIRCLE_NODE(MAX_POOL_2D, luci::CircleMaxPool2D)
-CIRCLE_NODE(MATRIX_SET_DIAG, luci::CircleMatrixSetDiag)
-CIRCLE_NODE(MAXIMUM, luci::CircleMaximum)
-CIRCLE_NODE(MEAN, luci::CircleMean)
-CIRCLE_NODE(MINIMUM, luci::CircleMinimum)
-CIRCLE_NODE(MIRROR_PAD, luci::CircleMirrorPad)
-CIRCLE_NODE(MUL, luci::CircleMul)
-CIRCLE_NODE(NEG, luci::CircleNeg)
-CIRCLE_NODE(NON_MAX_SUPPRESSION_V4, luci::CircleNonMaxSuppressionV4)
-CIRCLE_NODE(NON_MAX_SUPPRESSION_V5, luci::CircleNonMaxSuppressionV5)
-CIRCLE_NODE(NOT_EQUAL, luci::CircleNotEqual)
-CIRCLE_NODE(ONE_HOT, luci::CircleOneHot)
-CIRCLE_NODE(PACK, luci::CirclePack)
-CIRCLE_NODE(PAD, luci::CirclePad)
-CIRCLE_NODE(PADV2, luci::CirclePadV2)
-CIRCLE_NODE(POW, luci::CirclePow)
-CIRCLE_NODE(PRELU, luci::CirclePRelu)
-CIRCLE_NODE(RANGE, luci::CircleRange)
-CIRCLE_NODE(RANK, luci::CircleRank)
-CIRCLE_NODE(REDUCE_ANY, luci::CircleReduceAny)
-CIRCLE_NODE(REDUCE_MAX, luci::CircleReduceMax)
-CIRCLE_NODE(REDUCE_MIN, luci::CircleReduceMin)
-CIRCLE_NODE(REDUCE_PROD, luci::CircleReduceProd)
-CIRCLE_NODE(RELU, luci::CircleRelu)
-CIRCLE_NODE(RELU6, luci::CircleRelu6)
-CIRCLE_NODE(RELU_N1_TO_1, luci::CircleReluN1To1)
-CIRCLE_NODE(RESHAPE, luci::CircleReshape)
-CIRCLE_NODE(RESIZE_BILINEAR, luci::CircleResizeBilinear)
-CIRCLE_NODE(RESIZE_NEAREST_NEIGHBOR, luci::CircleResizeNearestNeighbor)
-CIRCLE_NODE(REVERSE_SEQUENCE, luci::CircleReverseSequence)
-CIRCLE_NODE(REVERSE_V2, luci::CircleReverseV2)
-CIRCLE_NODE(ROUND, luci::CircleRound)
-CIRCLE_NODE(RSQRT, luci::CircleRsqrt)
-CIRCLE_NODE(SCATTER_ND, luci::CircleScatterNd)
-CIRCLE_NODE(SEGMENT_SUM, luci::CircleSegmentSum)
-CIRCLE_NODE(SELECT, luci::CircleSelect)
-CIRCLE_NODE(SELECT_V2, luci::CircleSelectV2)
-CIRCLE_NODE(SHAPE, luci::CircleShape)
-CIRCLE_NODE(SIN, luci::CircleSin)
-CIRCLE_NODE(SLICE, luci::CircleSlice)
-CIRCLE_NODE(SOFTMAX, luci::CircleSoftmax)
-CIRCLE_NODE(SPACE_TO_BATCH_ND, luci::CircleSpaceToBatchND)
-CIRCLE_NODE(SPACE_TO_DEPTH, luci::CircleSpaceToDepth)
-CIRCLE_NODE(SPARSE_TO_DENSE, luci::CircleSparseToDense)
-CIRCLE_NODE(SPLIT, luci::CircleSplit)
-CIRCLE_NODE(SPLIT_V, luci::CircleSplitV)
-CIRCLE_NODE(SQRT, luci::CircleSqrt)
-CIRCLE_NODE(SQUARE, luci::CircleSquare)
-CIRCLE_NODE(SQUARED_DIFFERENCE, luci::CircleSquaredDifference)
-CIRCLE_NODE(SQUEEZE, luci::CircleSqueeze)
-CIRCLE_NODE(STRIDED_SLICE, luci::CircleStridedSlice)
-CIRCLE_NODE(SUB, luci::CircleSub)
-CIRCLE_NODE(SUM, luci::CircleSum)
-CIRCLE_NODE(TANH, luci::CircleTanh)
-CIRCLE_NODE(TILE, luci::CircleTile)
-CIRCLE_NODE(TOPK_V2, luci::CircleTopKV2)
-CIRCLE_NODE(TRANSPOSE, luci::CircleTranspose)
-CIRCLE_NODE(TRANSPOSE_CONV, luci::CircleTransposeConv)
-CIRCLE_NODE(UNIDIRECTIONAL_SEQUENCE_LSTM, luci::CircleUnidirectionalSequenceLSTM)
-CIRCLE_NODE(UNIQUE, luci::CircleUnique)
-CIRCLE_NODE(UNPACK, luci::CircleUnpack)
-CIRCLE_NODE(WHERE, luci::CircleWhere)
-CIRCLE_NODE(WHILE, luci::CircleWhile)
-CIRCLE_NODE(ZEROS_LIKE, luci::CircleZerosLike)
+CIRCLE_NODE(ABS, CircleAbs)
+CIRCLE_NODE(ADD, CircleAdd)
+CIRCLE_NODE(ADD_N, CircleAddN)
+CIRCLE_NODE(ARG_MAX, CircleArgMax)
+CIRCLE_NODE(ARG_MIN, CircleArgMin)
+CIRCLE_NODE(AVERAGE_POOL_2D, CircleAveragePool2D)
+CIRCLE_NODE(BATCH_TO_SPACE_ND, CircleBatchToSpaceND)
+CIRCLE_NODE(BATCH_MATMUL, CircleBatchMatMul)
+CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTM)
+CIRCLE_NODE(CAST, CircleCast)
+CIRCLE_NODE(CEIL, CircleCeil)
+CIRCLE_NODE(CONCATENATION, CircleConcatenation)
+CIRCLE_NODE(CONV_2D, CircleConv2D)
+CIRCLE_NODE(COS, CircleCos)
+CIRCLE_NODE(CUSTOM, CircleCustom)
+CIRCLE_NODE(DEPTH_TO_SPACE, CircleDepthToSpace)
+CIRCLE_NODE(DEPTHWISE_CONV_2D, CircleDepthwiseConv2D)
+CIRCLE_NODE(DEQUANTIZE, CircleDequantize)
+CIRCLE_NODE(DIV, CircleDiv)
+CIRCLE_NODE(ELU, CircleElu)
+CIRCLE_NODE(EQUAL, CircleEqual)
+CIRCLE_NODE(EXP, CircleExp)
+CIRCLE_NODE(EXPAND_DIMS, CircleExpandDims)
+CIRCLE_NODE(FAKE_QUANT, CircleFakeQuant)
+CIRCLE_NODE(FILL, CircleFill)
+CIRCLE_NODE(FLOOR, CircleFloor)
+CIRCLE_NODE(FLOOR_DIV, CircleFloorDiv)
+CIRCLE_NODE(FLOOR_MOD, CircleFloorMod)
+CIRCLE_NODE(FULLY_CONNECTED, CircleFullyConnected)
+CIRCLE_NODE(GATHER, CircleGather)
+CIRCLE_NODE(GATHER_ND, CircleGatherNd)
+CIRCLE_NODE(GREATER, CircleGreater)
+CIRCLE_NODE(GREATER_EQUAL, CircleGreaterEqual)
+CIRCLE_NODE(IF, CircleIf)
+CIRCLE_NODE(L2_NORMALIZATION, CircleL2Normalize)
+CIRCLE_NODE(L2_POOL_2D, CircleL2Pool2D)
+CIRCLE_NODE(LEAKY_RELU, CircleLeakyRelu)
+CIRCLE_NODE(LESS, CircleLess)
+CIRCLE_NODE(LESS_EQUAL, CircleLessEqual)
+CIRCLE_NODE(LOCAL_RESPONSE_NORMALIZATION, CircleLocalResponseNormalization)
+CIRCLE_NODE(LOG, CircleLog)
+CIRCLE_NODE(LOGICAL_AND, CircleLogicalAnd)
+CIRCLE_NODE(LOGICAL_NOT, CircleLogicalNot)
+CIRCLE_NODE(LOGICAL_OR, CircleLogicalOr)
+CIRCLE_NODE(LOGISTIC, CircleLogistic)
+CIRCLE_NODE(LOG_SOFTMAX, CircleLogSoftmax)
+CIRCLE_NODE(MATRIX_DIAG, CircleMatrixDiag)
+CIRCLE_NODE(MAX_POOL_2D, CircleMaxPool2D)
+CIRCLE_NODE(MATRIX_SET_DIAG, CircleMatrixSetDiag)
+CIRCLE_NODE(MAXIMUM, CircleMaximum)
+CIRCLE_NODE(MEAN, CircleMean)
+CIRCLE_NODE(MINIMUM, CircleMinimum)
+CIRCLE_NODE(MIRROR_PAD, CircleMirrorPad)
+CIRCLE_NODE(MUL, CircleMul)
+CIRCLE_NODE(NEG, CircleNeg)
+CIRCLE_NODE(NON_MAX_SUPPRESSION_V4, CircleNonMaxSuppressionV4)
+CIRCLE_NODE(NON_MAX_SUPPRESSION_V5, CircleNonMaxSuppressionV5)
+CIRCLE_NODE(NOT_EQUAL, CircleNotEqual)
+CIRCLE_NODE(ONE_HOT, CircleOneHot)
+CIRCLE_NODE(PACK, CirclePack)
+CIRCLE_NODE(PAD, CirclePad)
+CIRCLE_NODE(PADV2, CirclePadV2)
+CIRCLE_NODE(POW, CirclePow)
+CIRCLE_NODE(PRELU, CirclePRelu)
+CIRCLE_NODE(QUANTIZE, CircleQuantize)
+CIRCLE_NODE(RANGE, CircleRange)
+CIRCLE_NODE(RANK, CircleRank)
+CIRCLE_NODE(REDUCE_ANY, CircleReduceAny)
+CIRCLE_NODE(REDUCE_MAX, CircleReduceMax)
+CIRCLE_NODE(REDUCE_MIN, CircleReduceMin)
+CIRCLE_NODE(REDUCE_PROD, CircleReduceProd)
+CIRCLE_NODE(RELU, CircleRelu)
+CIRCLE_NODE(RELU6, CircleRelu6)
+CIRCLE_NODE(RELU_N1_TO_1, CircleReluN1To1)
+CIRCLE_NODE(RESHAPE, CircleReshape)
+CIRCLE_NODE(RESIZE_BILINEAR, CircleResizeBilinear)
+CIRCLE_NODE(RESIZE_NEAREST_NEIGHBOR, CircleResizeNearestNeighbor)
+CIRCLE_NODE(REVERSE_SEQUENCE, CircleReverseSequence)
+CIRCLE_NODE(REVERSE_V2, CircleReverseV2)
+CIRCLE_NODE(ROUND, CircleRound)
+CIRCLE_NODE(RSQRT, CircleRsqrt)
+CIRCLE_NODE(SCATTER_ND, CircleScatterNd)
+CIRCLE_NODE(SEGMENT_SUM, CircleSegmentSum)
+CIRCLE_NODE(SELECT, CircleSelect)
+CIRCLE_NODE(SELECT_V2, CircleSelectV2)
+CIRCLE_NODE(SHAPE, CircleShape)
+CIRCLE_NODE(SIN, CircleSin)
+CIRCLE_NODE(SLICE, CircleSlice)
+CIRCLE_NODE(SOFTMAX, CircleSoftmax)
+CIRCLE_NODE(SPACE_TO_BATCH_ND, CircleSpaceToBatchND)
+CIRCLE_NODE(SPACE_TO_DEPTH, CircleSpaceToDepth)
+CIRCLE_NODE(SPARSE_TO_DENSE, CircleSparseToDense)
+CIRCLE_NODE(SPLIT, CircleSplit)
+CIRCLE_NODE(SPLIT_V, CircleSplitV)
+CIRCLE_NODE(SQRT, CircleSqrt)
+CIRCLE_NODE(SQUARE, CircleSquare)
+CIRCLE_NODE(SQUARED_DIFFERENCE, CircleSquaredDifference)
+CIRCLE_NODE(SQUEEZE, CircleSqueeze)
+CIRCLE_NODE(STRIDED_SLICE, CircleStridedSlice)
+CIRCLE_NODE(SUB, CircleSub)
+CIRCLE_NODE(SUM, CircleSum)
+CIRCLE_NODE(TANH, CircleTanh)
+CIRCLE_NODE(TILE, CircleTile)
+CIRCLE_NODE(TOPK_V2, CircleTopKV2)
+CIRCLE_NODE(TRANSPOSE, CircleTranspose)
+CIRCLE_NODE(TRANSPOSE_CONV, CircleTransposeConv)
+CIRCLE_NODE(UNIDIRECTIONAL_SEQUENCE_LSTM, CircleUnidirectionalSequenceLSTM)
+CIRCLE_NODE(UNIQUE, CircleUnique)
+CIRCLE_NODE(UNPACK, CircleUnpack)
+CIRCLE_NODE(WHERE, CircleWhere)
+CIRCLE_NODE(WHILE, CircleWhile)
+CIRCLE_NODE(ZEROS_LIKE, CircleZerosLike)
// Circle Only
-CIRCLE_NODE(BCQ_FULLY_CONNECTED, luci::CircleBCQFullyConnected)
-CIRCLE_NODE(BCQ_GATHER, luci::CircleBCQGather)
-CIRCLE_NODE(INSTANCE_NORM, luci::CircleInstanceNorm)
+CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnected)
+CIRCLE_NODE(BCQ_GATHER, CircleBCQGather)
+CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNorm)
// Virtual node(s)
-CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, luci::CircleBidirectionalSequenceLSTMOut)
-CIRCLE_VNODE(CIRCLECONST, luci::CircleConst)
-CIRCLE_VNODE(CIRCLEINPUT, luci::CircleInput)
-CIRCLE_VNODE(CIRCLEOUTPUT, luci::CircleOutput)
-CIRCLE_VNODE(CIRCLEOUTPUTDUMMY, luci::CircleOutputDummy)
-CIRCLE_VNODE(CIRCLEOUTPUTEXCLUDE, luci::CircleOutputExclude)
-CIRCLE_VNODE(CIRCLECUSTOMOUT, luci::CircleCustomOut)
-CIRCLE_VNODE(CIRCLEIFOUT, luci::CircleIfOut)
-CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV4OUT, luci::CircleNonMaxSuppressionV4Out)
-CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV5OUT, luci::CircleNonMaxSuppressionV5Out)
-CIRCLE_VNODE(CIRCLESPLITOUT, luci::CircleSplitOut)
-CIRCLE_VNODE(CIRCLESPLITVOUT, luci::CircleSplitVOut)
-CIRCLE_VNODE(CIRCLETOPKV2OUT, luci::CircleTopKV2Out)
-CIRCLE_VNODE(CIRCLEUNIQUEOUT, luci::CircleUniqueOut)
-CIRCLE_VNODE(CIRCLEUNPACKOUT, luci::CircleUnpackOut)
-CIRCLE_VNODE(CIRCLEWHILEOUT, luci::CircleWhileOut)
+CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, CircleBidirectionalSequenceLSTMOut)
+CIRCLE_VNODE(CIRCLECONST, CircleConst)
+CIRCLE_VNODE(CIRCLEINPUT, CircleInput)
+CIRCLE_VNODE(CIRCLEOUTPUT, CircleOutput)
+CIRCLE_VNODE(CIRCLEOUTPUTDUMMY, CircleOutputDummy)
+CIRCLE_VNODE(CIRCLEOUTPUTEXCLUDE, CircleOutputExclude)
+CIRCLE_VNODE(CIRCLECUSTOMOUT, CircleCustomOut)
+CIRCLE_VNODE(CIRCLEIFOUT, CircleIfOut)
+CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV4OUT, CircleNonMaxSuppressionV4Out)
+CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV5OUT, CircleNonMaxSuppressionV5Out)
+CIRCLE_VNODE(CIRCLESPLITOUT, CircleSplitOut)
+CIRCLE_VNODE(CIRCLESPLITVOUT, CircleSplitVOut)
+CIRCLE_VNODE(CIRCLETOPKV2OUT, CircleTopKV2Out)
+CIRCLE_VNODE(CIRCLEUNIQUEOUT, CircleUniqueOut)
+CIRCLE_VNODE(CIRCLEUNPACKOUT, CircleUnpackOut)
+CIRCLE_VNODE(CIRCLEWHILEOUT, CircleWhileOut)
#include <loco/IR/Graph.h>
+#include <map>
#include <memory>
#include <vector>
// TODO provide graph accessor with a name
+public:
+ void source_table(const std::map<uint32_t, std::string> &table) { _source_table = table; }
+
+ const std::map<uint32_t, std::string> &source_table(void) const { return _source_table; }
+
private:
std::vector<std::unique_ptr<loco::Graph>> _graphs;
+
+private:
+ /**
+ * @brief Metadata about source table for profiling
+ *
+ * @note Key is ID of node and value is name of node.
+ *
+ * If the circle model originally contained 'source_table' metadata,
+ * that table is stored as-is.
+ * Otherwise, a new 'source_table' is created from the imported nodes.
+ *
+ * Even if the Module has multiple subgraphs, only the first subgraph is considered.
+ */
+ std::map<uint32_t, std::string> _source_table;
};
std::unique_ptr<Module> make_module(void);
private:
std::vector<uint8_t> _data;
+ // TODO use _data for STRING and remove _strings
+ std::vector<std::string> _strings; // for STRING type
};
} // namespace luci
CircleCustom(uint32_t arity, uint32_t out)
: VariadicArityNode<CircleNodeImpl<CircleOpcode::CUSTOM>>(arity), _output_count(out)
{
- // TODO Support when arity is 0
- assert(arity >= 1);
- assert(out > 0);
+ // NOTE Custom can have 0 input or 0 output but not both
+ assert(arity != 0 || out != 0);
}
public:
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IR_CIRCLEQUANTIZE_H__
+#define __LUCI_IR_CIRCLEQUANTIZE_H__
+
+#include "luci/IR/CircleNodeDecl.h"
+#include "luci/IR/CircleOpcode.h"
+
+#include "luci/IR/CircleNodeMixins.h"
+
+namespace luci
+{
+
+/**
+ * @brief QUANTIZE in Circle
+ */
+class CircleQuantize final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::QUANTIZE>>
+{
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+};
+
+} // namespace luci
+
+#endif // __LUCI_IR_CIRCLEQUANTIZE_H__
#undef INSTANTIATE
+// CircleConst implementations for loco::DataType::STRING
+
+template <> uint32_t CircleConst::size<loco::DataType::STRING>(void) const
+{
+ assert(dtype() == loco::DataType::STRING);
+ assert(_data.size() == 0);
+ return _strings.size();
+}
+
+template <> void CircleConst::size<loco::DataType::STRING>(uint32_t l)
+{
+ assert(dtype() == loco::DataType::STRING);
+ assert(_data.size() == 0);
+ _strings.resize(l);
+}
+
+template <> const std::string &CircleConst::at<loco::DataType::STRING>(uint32_t n) const
+{
+ assert(dtype() == loco::DataType::STRING);
+ assert(n < _strings.size());
+ return _strings.at(n);
+}
+
+template <> std::string &CircleConst::at<loco::DataType::STRING>(uint32_t n)
+{
+ assert(dtype() == loco::DataType::STRING);
+ assert(n < _strings.size());
+ return _strings.at(n);
+}
+
+template <> const std::string &CircleConst::scalar<loco::DataType::STRING>(void) const
+{
+ assert(dtype() == loco::DataType::STRING);
+ assert(1 == _strings.size());
+ return _strings.at(0);
+}
+
+template <> std::string &CircleConst::scalar<loco::DataType::STRING>(void)
+{
+ assert(dtype() == loco::DataType::STRING);
+ assert(1 == _strings.size());
+ return _strings.at(0);
+}
+
} // namespace luci
auto const &cs = const_node.scalar<loco::DataType::S32>();
ASSERT_EQ(1, cs);
}
+
+TEST(CircleConstTest, string)
+{
+ luci::CircleConst const_node;
+
+ const_node.dtype(loco::DataType::STRING);
+ const_node.size<loco::DataType::STRING>(1);
+ const_node.at<loco::DataType::STRING>(0) = std::string("Hello");
+
+ ASSERT_EQ(loco::DataType::STRING, const_node.dtype());
+ ASSERT_EQ(1, const_node.size<loco::DataType::STRING>());
+ EXPECT_TRUE(std::string("Hello") == const_node.at<loco::DataType::STRING>(0));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/Nodes/CircleQuantize.h"
+
+#include "luci/IR/CircleDialect.h"
+#include "luci/IR/CircleNodeVisitor.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+TEST(CircleQuantizeTest, constructor)
+{
+ luci::CircleQuantize quant_node;
+
+ ASSERT_EQ(luci::CircleDialect::get(), quant_node.dialect());
+ ASSERT_EQ(luci::CircleOpcode::QUANTIZE, quant_node.opcode());
+
+ ASSERT_EQ(nullptr, quant_node.input());
+}
+
+TEST(CircleQuantizeTest, common_NEG)
+{
+ luci::CircleQuantize quant_node;
+
+ quant_node.name("name");
+ ASSERT_EQ("name", quant_node.name());
+
+ auto q = std::make_unique<luci::CircleQuantParam>();
+ quant_node.quantparam(std::move(q));
+ ASSERT_NE(nullptr, quant_node.quantparam());
+
+ ASSERT_EQ(luci::ShapeStatus::UNDEFINED, quant_node.shape_status());
+ quant_node.shape_status(luci::ShapeStatus::NOSHAPE);
+ ASSERT_NE(luci::ShapeStatus::UNDEFINED, quant_node.shape_status());
+}
+
+TEST(CircleQuantizeTest, input_NEG)
+{
+ luci::CircleQuantize quant_node;
+ luci::CircleQuantize node;
+
+ quant_node.input(&node);
+ ASSERT_NE(nullptr, quant_node.input());
+
+ quant_node.input(nullptr);
+ ASSERT_EQ(nullptr, quant_node.input());
+}
+
+TEST(CircleQuantizeTest, arity_NEG)
+{
+ luci::CircleQuantize quant_node;
+
+ ASSERT_NO_THROW(quant_node.arg(0));
+ ASSERT_THROW(quant_node.arg(1), std::out_of_range);
+}
+
+TEST(CircleQuantizeTest, visit_mutable_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeMutableVisitor<void>
+ {
+ };
+
+ luci::CircleQuantize quant_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(quant_node.accept(&tv), std::exception);
+}
+
+TEST(CircleQuantizeTest, visit_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeVisitor<void>
+ {
+ };
+
+ luci::CircleQuantize quant_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(quant_node.accept(&tv), std::exception);
+}
target_link_libraries(luci_log PRIVATE nncc_common)
target_link_libraries(luci_log PRIVATE luci_env)
install(TARGETS luci_log DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
target_link_libraries(luci_logex PRIVATE nncc_common)
target_link_libraries(luci_logex PRIVATE pepper_str)
install(TARGETS luci_logex DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
#include <sstream>
#include <vector>
+using namespace luci;
/**
* @brief dump std::vector<int64_t> values to stream
*/
bool build(const loco::Node *, locop::NodeSummary &s) const final;
protected:
-#define CIRCLE_NODE(OPCODE, CLASS) \
- virtual bool summary(const CLASS *, locop::NodeSummary &s) const \
- { \
- s.comments().append("Emitted by Default CircleNodeSummaryBuilder"); \
- s.state(locop::NodeSummary::State::PartiallyKnown); \
- return true; \
- }
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ virtual bool summary(const CLASS *, locop::NodeSummary &) const { return false; }
#define CIRCLE_VNODE CIRCLE_NODE
#include <luci/IR/CircleNodes.lst>
#undef CIRCLE_VNODE
protected:
const locop::SymbolTable *tbl(void) const { return _tbl; }
- // Please do not use _tbl directly and use tbl().
- // This will be changed to private in near future.
-protected:
- const locop::SymbolTable *_tbl;
-};
-
-class CircleNodeSummaryBuilder final : public CircleNodeSummaryBuilderBase
-{
-public:
- CircleNodeSummaryBuilder(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
- {
- // DO NOTHING
- }
-
private:
-#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final;
- IMPLEMENT(luci::CircleAbs)
- IMPLEMENT(luci::CircleAdd)
- IMPLEMENT(luci::CircleAddN)
- IMPLEMENT(luci::CircleArgMax)
- IMPLEMENT(luci::CircleArgMin)
- IMPLEMENT(luci::CircleAveragePool2D)
- IMPLEMENT(luci::CircleBatchMatMul)
- IMPLEMENT(luci::CircleBatchToSpaceND)
- IMPLEMENT(luci::CircleBidirectionalSequenceLSTM)
- IMPLEMENT(luci::CircleCast)
- IMPLEMENT(luci::CircleCeil)
- IMPLEMENT(luci::CircleConcatenation)
- IMPLEMENT(luci::CircleConst)
- IMPLEMENT(luci::CircleConv2D)
- IMPLEMENT(luci::CircleCos)
- IMPLEMENT(luci::CircleCustom)
- IMPLEMENT(luci::CircleDepthToSpace)
- IMPLEMENT(luci::CircleDepthwiseConv2D)
- IMPLEMENT(luci::CircleDequantize)
- IMPLEMENT(luci::CircleDiv)
- IMPLEMENT(luci::CircleElu)
- IMPLEMENT(luci::CircleExp)
- IMPLEMENT(luci::CircleExpandDims)
- IMPLEMENT(luci::CircleFakeQuant)
- IMPLEMENT(luci::CircleFill)
- IMPLEMENT(luci::CircleFloor)
- IMPLEMENT(luci::CircleFloorDiv)
- IMPLEMENT(luci::CircleFloorMod)
- IMPLEMENT(luci::CircleFullyConnected)
- IMPLEMENT(luci::CircleGather)
- IMPLEMENT(luci::CircleGatherNd)
- IMPLEMENT(luci::CircleGreater)
- IMPLEMENT(luci::CircleGreaterEqual)
- IMPLEMENT(luci::CircleIf)
- IMPLEMENT(luci::CircleL2Normalize)
- IMPLEMENT(luci::CircleLeakyRelu)
- IMPLEMENT(luci::CircleLess)
- IMPLEMENT(luci::CircleLessEqual)
- IMPLEMENT(luci::CircleLocalResponseNormalization)
- IMPLEMENT(luci::CircleLog)
- IMPLEMENT(luci::CircleLogicalAnd)
- IMPLEMENT(luci::CircleLogicalNot)
- IMPLEMENT(luci::CircleLogicalOr)
- IMPLEMENT(luci::CircleLogistic)
- IMPLEMENT(luci::CircleLogSoftmax)
- IMPLEMENT(luci::CircleMatrixDiag)
- IMPLEMENT(luci::CircleMatrixSetDiag)
- IMPLEMENT(luci::CircleMaximum)
- IMPLEMENT(luci::CircleMaxPool2D)
- IMPLEMENT(luci::CircleMean)
- IMPLEMENT(luci::CircleMinimum)
- IMPLEMENT(luci::CircleMirrorPad)
- IMPLEMENT(luci::CircleMul)
- IMPLEMENT(luci::CircleNeg)
- IMPLEMENT(luci::CircleNonMaxSuppressionV4)
- IMPLEMENT(luci::CircleNonMaxSuppressionV5)
- IMPLEMENT(luci::CircleNotEqual)
- IMPLEMENT(luci::CircleOneHot)
- IMPLEMENT(luci::CirclePack)
- IMPLEMENT(luci::CirclePad)
- IMPLEMENT(luci::CirclePadV2)
- IMPLEMENT(luci::CirclePow)
- IMPLEMENT(luci::CirclePRelu)
- IMPLEMENT(luci::CircleRange)
- IMPLEMENT(luci::CircleRank)
- IMPLEMENT(luci::CircleReduceAny)
- IMPLEMENT(luci::CircleReduceMax)
- IMPLEMENT(luci::CircleReduceMin)
- IMPLEMENT(luci::CircleReduceProd)
- IMPLEMENT(luci::CircleRelu)
- IMPLEMENT(luci::CircleRelu6)
- IMPLEMENT(luci::CircleReluN1To1)
- IMPLEMENT(luci::CircleReshape)
- IMPLEMENT(luci::CircleResizeBilinear)
- IMPLEMENT(luci::CircleResizeNearestNeighbor)
- IMPLEMENT(luci::CircleReverseSequence)
- IMPLEMENT(luci::CircleReverseV2)
- IMPLEMENT(luci::CircleRound)
- IMPLEMENT(luci::CircleRsqrt)
- IMPLEMENT(luci::CircleScatterNd)
- IMPLEMENT(luci::CircleSegmentSum)
- IMPLEMENT(luci::CircleSelect)
- IMPLEMENT(luci::CircleSelectV2)
- IMPLEMENT(luci::CircleShape)
- IMPLEMENT(luci::CircleSin)
- IMPLEMENT(luci::CircleSlice)
- IMPLEMENT(luci::CircleSoftmax)
- IMPLEMENT(luci::CircleSpaceToBatchND)
- IMPLEMENT(luci::CircleSpaceToDepth)
- IMPLEMENT(luci::CircleSparseToDense)
- IMPLEMENT(luci::CircleSplit)
- IMPLEMENT(luci::CircleSplitV)
- IMPLEMENT(luci::CircleSqrt)
- IMPLEMENT(luci::CircleSquare)
- IMPLEMENT(luci::CircleSquaredDifference)
- IMPLEMENT(luci::CircleSqueeze)
- IMPLEMENT(luci::CircleStridedSlice)
- IMPLEMENT(luci::CircleSub)
- IMPLEMENT(luci::CircleSum)
- IMPLEMENT(luci::CircleTanh)
- IMPLEMENT(luci::CircleTile)
- IMPLEMENT(luci::CircleTopKV2)
- IMPLEMENT(luci::CircleTranspose)
- IMPLEMENT(luci::CircleTransposeConv)
- IMPLEMENT(luci::CircleUnidirectionalSequenceLSTM)
- IMPLEMENT(luci::CircleUnique)
- IMPLEMENT(luci::CircleUnpack)
- IMPLEMENT(luci::CircleWhere)
- IMPLEMENT(luci::CircleWhile)
- IMPLEMENT(luci::CircleZerosLike)
- // Circle Only
- IMPLEMENT(luci::CircleBCQFullyConnected)
- IMPLEMENT(luci::CircleBCQGather)
- IMPLEMENT(luci::CircleInstanceNorm)
- // Virtual nodes
- IMPLEMENT(luci::CircleInput)
- IMPLEMENT(luci::CircleOutput)
- IMPLEMENT(luci::CircleIfOut)
- IMPLEMENT(luci::CircleNonMaxSuppressionV4Out)
- IMPLEMENT(luci::CircleNonMaxSuppressionV5Out)
- IMPLEMENT(luci::CircleSplitOut)
- IMPLEMENT(luci::CircleSplitVOut)
- IMPLEMENT(luci::CircleTopKV2Out)
- IMPLEMENT(luci::CircleUniqueOut)
- IMPLEMENT(luci::CircleUnpackOut)
- IMPLEMENT(luci::CircleWhileOut)
-#undef IMPLEMENT
+ const locop::SymbolTable *_tbl;
};
template <class CIRCLENODE>
return true;
}
+bool summary_node(const locop::SymbolTable *tbl, const luci::CircleL2Pool2D *node,
+ locop::NodeSummary &s)
+{
+ assert(node->fusedActivationFunction() != luci::FusedActFunc::UNDEFINED);
+
+ s.args().append("value", tbl->lookup(node->value()));
+ s.args().append("filter(h,w)", to_str(node->filter()));
+ s.args().append("stride(h,w)", to_str(node->stride()));
+ s.args().append("padding", to_str(node->padding()));
+ s.args().append("fused", to_str(node->fusedActivationFunction()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
bool summary_node(const locop::SymbolTable *tbl, const luci::CircleLeakyRelu *node,
locop::NodeSummary &s)
{
return true;
}
+bool summary_node(const locop::SymbolTable *, const luci::CircleOutputDummy *,
+ locop::NodeSummary &s)
+{
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
+bool summary_node(const locop::SymbolTable *, const luci::CircleOutputExclude *,
+ locop::NodeSummary &s)
+{
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBCQFullyConnected *node,
locop::NodeSummary &s)
{
return true;
}
+// SummaryBuilderLet type
+enum class SB
+{
+ ABC,
+ DEF,
+ GHIJ,
+ KLMN,
+ OPQR,
+ STUV,
+ WXYZ,
+ CIRC, // circle only
+ VIRT, // virtual
+};
+
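+// NOTE Summary builders are grouped by the first letter of the op name:
+// e.g. luci::CircleAbs is handled by SummaryBuilderLet<SB::ABC> and
+// luci::CircleQuantize by SummaryBuilderLet<SB::OPQR>.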
+template <SB sb> class SummaryBuilderLet;
+
+#define IMPLEMENT(CLASS) bool summary(const CLASS *, locop::NodeSummary &) const final;
+
+template <> class SummaryBuilderLet<SB::ABC> final : public CircleNodeSummaryBuilderBase
+{
+public:
+ SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+ IMPLEMENT(luci::CircleAbs)
+ IMPLEMENT(luci::CircleAdd)
+ IMPLEMENT(luci::CircleAddN)
+ IMPLEMENT(luci::CircleArgMax)
+ IMPLEMENT(luci::CircleArgMin)
+ IMPLEMENT(luci::CircleAveragePool2D)
+ IMPLEMENT(luci::CircleBatchMatMul)
+ IMPLEMENT(luci::CircleBatchToSpaceND)
+ IMPLEMENT(luci::CircleBidirectionalSequenceLSTM)
+ IMPLEMENT(luci::CircleCast)
+ IMPLEMENT(luci::CircleCeil)
+ IMPLEMENT(luci::CircleConcatenation)
+ IMPLEMENT(luci::CircleConst)
+ IMPLEMENT(luci::CircleConv2D)
+ IMPLEMENT(luci::CircleCos)
+ IMPLEMENT(luci::CircleCustom)
+};
+
+template <> class SummaryBuilderLet<SB::DEF> final : public CircleNodeSummaryBuilderBase
+{
+public:
+ SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+ IMPLEMENT(luci::CircleDepthToSpace)
+ IMPLEMENT(luci::CircleDepthwiseConv2D)
+ IMPLEMENT(luci::CircleDequantize)
+ IMPLEMENT(luci::CircleDiv)
+ IMPLEMENT(luci::CircleElu)
+ IMPLEMENT(luci::CircleEqual)
+ IMPLEMENT(luci::CircleExp)
+ IMPLEMENT(luci::CircleExpandDims)
+ IMPLEMENT(luci::CircleFakeQuant)
+ IMPLEMENT(luci::CircleFill)
+ IMPLEMENT(luci::CircleFloor)
+ IMPLEMENT(luci::CircleFloorDiv)
+ IMPLEMENT(luci::CircleFloorMod)
+ IMPLEMENT(luci::CircleFullyConnected)
+};
+
+template <> class SummaryBuilderLet<SB::GHIJ> final : public CircleNodeSummaryBuilderBase
+{
+public:
+ SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+ IMPLEMENT(luci::CircleGather)
+ IMPLEMENT(luci::CircleGatherNd)
+ IMPLEMENT(luci::CircleGreater)
+ IMPLEMENT(luci::CircleGreaterEqual)
+ IMPLEMENT(luci::CircleIf)
+};
+
+template <> class SummaryBuilderLet<SB::KLMN> final : public CircleNodeSummaryBuilderBase
+{
+public:
+ SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+ IMPLEMENT(luci::CircleL2Normalize)
+ IMPLEMENT(luci::CircleL2Pool2D)
+ IMPLEMENT(luci::CircleLeakyRelu)
+ IMPLEMENT(luci::CircleLess)
+ IMPLEMENT(luci::CircleLessEqual)
+ IMPLEMENT(luci::CircleLocalResponseNormalization)
+ IMPLEMENT(luci::CircleLog)
+ IMPLEMENT(luci::CircleLogicalAnd)
+ IMPLEMENT(luci::CircleLogicalNot)
+ IMPLEMENT(luci::CircleLogicalOr)
+ IMPLEMENT(luci::CircleLogistic)
+ IMPLEMENT(luci::CircleLogSoftmax)
+ IMPLEMENT(luci::CircleMatrixDiag)
+ IMPLEMENT(luci::CircleMatrixSetDiag)
+ IMPLEMENT(luci::CircleMaximum)
+ IMPLEMENT(luci::CircleMaxPool2D)
+ IMPLEMENT(luci::CircleMean)
+ IMPLEMENT(luci::CircleMinimum)
+ IMPLEMENT(luci::CircleMirrorPad)
+ IMPLEMENT(luci::CircleMul)
+ IMPLEMENT(luci::CircleNeg)
+ IMPLEMENT(luci::CircleNonMaxSuppressionV4)
+ IMPLEMENT(luci::CircleNonMaxSuppressionV5)
+ IMPLEMENT(luci::CircleNotEqual)
+};
+
+template <> class SummaryBuilderLet<SB::OPQR> final : public CircleNodeSummaryBuilderBase
+{
+public:
+ SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+ IMPLEMENT(luci::CircleOneHot)
+ IMPLEMENT(luci::CirclePack)
+ IMPLEMENT(luci::CirclePad)
+ IMPLEMENT(luci::CirclePadV2)
+ IMPLEMENT(luci::CirclePow)
+ IMPLEMENT(luci::CirclePRelu)
+ IMPLEMENT(luci::CircleQuantize)
+ IMPLEMENT(luci::CircleRange)
+ IMPLEMENT(luci::CircleRank)
+ IMPLEMENT(luci::CircleReduceAny)
+ IMPLEMENT(luci::CircleReduceMax)
+ IMPLEMENT(luci::CircleReduceMin)
+ IMPLEMENT(luci::CircleReduceProd)
+ IMPLEMENT(luci::CircleRelu)
+ IMPLEMENT(luci::CircleRelu6)
+ IMPLEMENT(luci::CircleReluN1To1)
+ IMPLEMENT(luci::CircleReshape)
+ IMPLEMENT(luci::CircleResizeBilinear)
+ IMPLEMENT(luci::CircleResizeNearestNeighbor)
+ IMPLEMENT(luci::CircleReverseSequence)
+ IMPLEMENT(luci::CircleReverseV2)
+ IMPLEMENT(luci::CircleRound)
+ IMPLEMENT(luci::CircleRsqrt)
+};
+
+template <> class SummaryBuilderLet<SB::STUV> final : public CircleNodeSummaryBuilderBase
+{
+public:
+ SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+ IMPLEMENT(luci::CircleScatterNd)
+ IMPLEMENT(luci::CircleSegmentSum)
+ IMPLEMENT(luci::CircleSelect)
+ IMPLEMENT(luci::CircleSelectV2)
+ IMPLEMENT(luci::CircleShape)
+ IMPLEMENT(luci::CircleSin)
+ IMPLEMENT(luci::CircleSlice)
+ IMPLEMENT(luci::CircleSoftmax)
+ IMPLEMENT(luci::CircleSpaceToBatchND)
+ IMPLEMENT(luci::CircleSpaceToDepth)
+ IMPLEMENT(luci::CircleSparseToDense)
+ IMPLEMENT(luci::CircleSplit)
+ IMPLEMENT(luci::CircleSplitV)
+ IMPLEMENT(luci::CircleSqrt)
+ IMPLEMENT(luci::CircleSquare)
+ IMPLEMENT(luci::CircleSquaredDifference)
+ IMPLEMENT(luci::CircleSqueeze)
+ IMPLEMENT(luci::CircleStridedSlice)
+ IMPLEMENT(luci::CircleSub)
+ IMPLEMENT(luci::CircleSum)
+ IMPLEMENT(luci::CircleTanh)
+ IMPLEMENT(luci::CircleTile)
+ IMPLEMENT(luci::CircleTopKV2)
+ IMPLEMENT(luci::CircleTranspose)
+ IMPLEMENT(luci::CircleTransposeConv)
+ IMPLEMENT(luci::CircleUnidirectionalSequenceLSTM)
+ IMPLEMENT(luci::CircleUnique)
+ IMPLEMENT(luci::CircleUnpack)
+};
+
+template <> class SummaryBuilderLet<SB::WXYZ> final : public CircleNodeSummaryBuilderBase
+{
+public:
+ SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+ IMPLEMENT(luci::CircleWhere)
+ IMPLEMENT(luci::CircleWhile)
+ IMPLEMENT(luci::CircleZerosLike)
+};
+
+template <> class SummaryBuilderLet<SB::CIRC> final : public CircleNodeSummaryBuilderBase
+{
+public:
+ SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+ IMPLEMENT(luci::CircleBCQFullyConnected)
+ IMPLEMENT(luci::CircleBCQGather)
+ IMPLEMENT(luci::CircleInstanceNorm)
+};
+
+template <> class SummaryBuilderLet<SB::VIRT> final : public CircleNodeSummaryBuilderBase
+{
+public:
+ SummaryBuilderLet(const locop::SymbolTable *tbl) : CircleNodeSummaryBuilderBase(tbl)
+ {
+ // DO NOTHING
+ }
+
+private:
+ IMPLEMENT(luci::CircleInput)
+ IMPLEMENT(luci::CircleOutput)
+ IMPLEMENT(luci::CircleCustomOut)
+ IMPLEMENT(luci::CircleIfOut)
+ IMPLEMENT(luci::CircleNonMaxSuppressionV4Out)
+ IMPLEMENT(luci::CircleNonMaxSuppressionV5Out)
+ IMPLEMENT(luci::CircleOutputDummy)
+ IMPLEMENT(luci::CircleOutputExclude)
+ IMPLEMENT(luci::CircleSplitOut)
+ IMPLEMENT(luci::CircleSplitVOut)
+ IMPLEMENT(luci::CircleTopKV2Out)
+ IMPLEMENT(luci::CircleUniqueOut)
+ IMPLEMENT(luci::CircleUnpackOut)
+ IMPLEMENT(luci::CircleWhileOut)
+};
+
+#undef IMPLEMENT
+
bool CircleNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary &s) const
{
if (node->dialect() != luci::CircleDialect::get())
return ss.str();
};
-#define CIRCLE_NODE(OPCODE, CLASS) \
- if (dynamic_cast<const CLASS *>(node)) \
- { \
- s.opname(circle_opname(node->opnum())); \
- s.comments().append("Mem = " + ptr_to_str(node)); \
- return summary(dynamic_cast<const CLASS *>(node), s); \
+ auto add_comment = [&]() {
+ auto cnode = loco::must_cast<const luci::CircleNode *>(node);
+ s.opname(circle_opname(node->opnum()));
+ s.comments().append("[" + cnode->name() + "] = " + ptr_to_str(node));
+ };
+
+#define CIRCLE_NODE(OPCODE, CLASS) \
+ if (dynamic_cast<const CLASS *>(node)) \
+ { \
+ if (summary(dynamic_cast<const CLASS *>(node), s)) \
+ { \
+ add_comment(); \
+ return true; \
+ } \
}
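+// NOTE A builder appends the op name and the "[node name] = address" comment only
+// when its summary() actually handled the node; for classes it does not implement,
+// the default summary() returns false and build() falls through until it finally
+// returns false.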
#define CIRCLE_VNODE CIRCLE_NODE
#include <luci/IR/CircleNodes.lst>
return false;
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleAbs *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAbs *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleAdd *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAdd *node, locop::NodeSummary &s) const
{
return use_xy_act(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleAddN *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAddN *node, locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleArgMax *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleArgMax *node,
+ locop::NodeSummary &s) const
{
return use_ido(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleArgMin *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleArgMin *node,
+ locop::NodeSummary &s) const
{
return use_ido(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleAveragePool2D *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleAveragePool2D *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleBatchMatMul *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBatchMatMul *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleBatchToSpaceND *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBatchToSpaceND *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleBidirectionalSequenceLSTM *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleBidirectionalSequenceLSTM *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleCast *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCast *node, locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleCeil *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCeil *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleConcatenation *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConcatenation *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleConst *, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConst *, locop::NodeSummary &s) const
{
s.state(locop::NodeSummary::State::PartiallyKnown);
return true;
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleConv2D *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleConv2D *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleCos *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCos *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleCustom *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::ABC>::summary(const luci::CircleCustom *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleDepthToSpace *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDepthToSpace *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleDepthwiseConv2D *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDepthwiseConv2D *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleDequantize *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDequantize *node,
+ locop::NodeSummary &s) const
{
return use_input(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleDiv *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleDiv *node, locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleElu *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleElu *node, locop::NodeSummary &s) const
{
return use_features(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleExp *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleEqual *node, locop::NodeSummary &s) const
+{
+ return use_xy(tbl(), node, s);
+}
+
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleExp *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleExpandDims *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleExpandDims *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleFakeQuant *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFakeQuant *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleFill *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFill *node, locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleFloor *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloor *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleFloorDiv *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloorDiv *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleFloorMod *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFloorMod *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleFullyConnected *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::DEF>::summary(const luci::CircleFullyConnected *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleGather *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGather *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleGatherNd *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGatherNd *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleGreater *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGreater *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleGreaterEqual *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleGreaterEqual *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleIf *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::GHIJ>::summary(const luci::CircleIf *node, locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleL2Normalize *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleL2Normalize *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLess *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleL2Pool2D *node,
+ locop::NodeSummary &s) const
+{
+ return summary_node(tbl(), node, s);
+}
+
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLess *node, locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLessEqual *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLessEqual *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLeakyRelu *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLeakyRelu *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLocalResponseNormalization *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLocalResponseNormalization *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLog *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLog *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLogicalAnd *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalAnd *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLogicalNot *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalNot *node,
+ locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLogicalOr *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogicalOr *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLogistic *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogistic *node,
+ locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleLogSoftmax *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleLogSoftmax *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleMatrixDiag *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMatrixDiag *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleMatrixSetDiag *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMatrixSetDiag *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleMaximum *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMaximum *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleMaxPool2D *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMaxPool2D *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleMean *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMean *node, locop::NodeSummary &s) const
{
return use_reducer(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleMinimum *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMinimum *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleMirrorPad *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMirrorPad *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleMul *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleMul *node, locop::NodeSummary &s) const
{
return use_xy_act(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleNeg *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNeg *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleNonMaxSuppressionV4 *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNonMaxSuppressionV4 *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleNonMaxSuppressionV5 *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNonMaxSuppressionV5 *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleNotEqual *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::KLMN>::summary(const luci::CircleNotEqual *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleOneHot *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleOneHot *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CirclePack *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePack *node, locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CirclePad *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePad *node, locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CirclePadV2 *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePadV2 *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CirclePow *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePow *node, locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CirclePRelu *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CirclePRelu *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleRange *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleQuantize *node,
+ locop::NodeSummary &s) const
+{
+ return use_input(tbl(), node, s);
+}
+
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRange *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleRank *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRank *node, locop::NodeSummary &s) const
{
return use_input(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleReduceAny *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceAny *node,
+ locop::NodeSummary &s) const
{
return use_reducer(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleReduceMax *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceMax *node,
+ locop::NodeSummary &s) const
{
return use_reducer(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleReduceMin *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceMin *node,
+ locop::NodeSummary &s) const
{
return use_reducer(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleReduceProd *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReduceProd *node,
+ locop::NodeSummary &s) const
{
return use_reducer(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleRelu *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRelu *node, locop::NodeSummary &s) const
{
return use_features(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleRelu6 *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRelu6 *node,
+ locop::NodeSummary &s) const
{
return use_features(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleReluN1To1 *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReluN1To1 *node,
+ locop::NodeSummary &s) const
{
return use_features(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleReshape *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReshape *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleResizeBilinear *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleResizeBilinear *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleResizeNearestNeighbor *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleResizeNearestNeighbor *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleReverseSequence *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReverseSequence *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleReverseV2 *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleReverseV2 *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleRound *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRound *node,
+ locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleRsqrt *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::OPQR>::summary(const luci::CircleRsqrt *node,
+ locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleScatterNd *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleScatterNd *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSegmentSum *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSegmentSum *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSelect *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSelect *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSelectV2 *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSelectV2 *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleShape *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleShape *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSin *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSin *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSlice *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSlice *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSoftmax *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSoftmax *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSpaceToBatchND *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSpaceToBatchND *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSpaceToDepth *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSpaceToDepth *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSparseToDense *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSparseToDense *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSplit *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSplit *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSplitV *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSplitV *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSqrt *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSqrt *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSquare *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSquare *node,
+ locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSquaredDifference *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSquaredDifference *node,
+ locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSqueeze *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSqueeze *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleStridedSlice *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleStridedSlice *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSub *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSub *node, locop::NodeSummary &s) const
{
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSum *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleSum *node, locop::NodeSummary &s) const
{
return use_reducer(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleTanh *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTanh *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleTile *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTile *node, locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleTopKV2 *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTopKV2 *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleTranspose *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTranspose *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleTransposeConv *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleTransposeConv *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleUnidirectionalSequenceLSTM *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnidirectionalSequenceLSTM *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleUnique *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnique *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleUnpack *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::STUV>::summary(const luci::CircleUnpack *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleWhere *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleWhere *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleWhile *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleWhile *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleZerosLike *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::WXYZ>::summary(const luci::CircleZerosLike *node,
+ locop::NodeSummary &s) const
{
return use_input(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSplitOut *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleBCQFullyConnected *node,
+ locop::NodeSummary &s) const
{
- return use_input(tbl(), node, s);
+ return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleSplitVOut *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleBCQGather *node,
+ locop::NodeSummary &s) const
{
- return use_input(tbl(), node, s);
+ return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleTopKV2Out *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::CIRC>::summary(const luci::CircleInstanceNorm *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleUniqueOut *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleInput *, locop::NodeSummary &s) const
{
- return summary_node(tbl(), node, s);
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleUnpackOut *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutput *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleIfOut *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleCustomOut *node,
+ locop::NodeSummary &s) const
+{
+ return use_input(tbl(), node, s);
+}
+
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleIfOut *node,
+ locop::NodeSummary &s) const
{
return use_input(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleNonMaxSuppressionV4Out *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleNonMaxSuppressionV4Out *node,
+ locop::NodeSummary &s) const
{
return use_input(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleNonMaxSuppressionV5Out *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleNonMaxSuppressionV5Out *node,
+ locop::NodeSummary &s) const
{
return use_input(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleWhileOut *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutputDummy *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleInput *, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleOutputExclude *node,
+ locop::NodeSummary &s) const
{
- s.state(locop::NodeSummary::State::Complete);
- return true;
+ return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleOutput *node, locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleSplitOut *node,
+ locop::NodeSummary &s) const
+{
+ return use_input(tbl(), node, s);
+}
+
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleSplitVOut *node,
+ locop::NodeSummary &s) const
+{
+ return use_input(tbl(), node, s);
+}
+
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleTopKV2Out *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleBCQFullyConnected *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleUniqueOut *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleBCQGather *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleUnpackOut *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleInstanceNorm *node,
- locop::NodeSummary &s) const
+bool SummaryBuilderLet<SB::VIRT>::summary(const luci::CircleWhileOut *node,
+ locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
}
return true;
}
- if (CircleNodeSummaryBuilder(_tbl).build(node, s))
- {
- return true;
- }
+#define BUILD_GRP(GRP) \
+ do \
+ { \
+ if (SummaryBuilderLet<SB::GRP>(_tbl).build(node, s)) \
+ return true; \
+ } while (false)
+
+ BUILD_GRP(ABC);
+ BUILD_GRP(DEF);
+ BUILD_GRP(GHIJ);
+ BUILD_GRP(KLMN);
+ BUILD_GRP(OPQR);
+ BUILD_GRP(STUV);
+ BUILD_GRP(WXYZ);
+ BUILD_GRP(CIRC);
+ BUILD_GRP(VIRT);
return false;
}
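Editorial note (not part of the diff): the change above splits the former monolithic CircleNodeSummaryBuilder into per-letter-group SummaryBuilderLet<SB::GROUP> specializations, and the BUILD_GRP macro simply tries each group until one accepts the node. A minimal, self-contained sketch of that dispatch pattern is below; only the names SB and SummaryBuilderLet come from the diff, everything else (Node, build_summary, the letter check) is hypothetical.

#include <iostream>

enum class SB { ABC, KLMN, VIRT };

struct Node { char initial; }; // stand-in for luci::CircleNode

template <SB Group> struct SummaryBuilderLet
{
  // Each specialization summarizes only the opcodes of its letter group.
  bool build(const Node &n) const
  {
    bool mine = (Group == SB::ABC) ? (n.initial >= 'A' && n.initial <= 'C')
                                   : (Group == SB::KLMN && n.initial >= 'K' && n.initial <= 'N');
    if (mine)
      std::cout << "summarized\n";
    return mine;
  }
};

bool build_summary(const Node &n)
{
  // Same shape as the BUILD_GRP chain above: first accepting group wins.
  if (SummaryBuilderLet<SB::ABC>{}.build(n))
    return true;
  if (SummaryBuilderLet<SB::KLMN>{}.build(n))
    return true;
  return false; // unknown node
}

int main() { return build_summary(Node{'M'}) ? 0 : 1; }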
target_link_libraries(luci_partition PRIVATE luci_logex)
target_link_libraries(luci_partition PRIVATE mio_circle)
target_link_libraries(luci_partition PRIVATE nncc_common)
+target_link_libraries(luci_partition PRIVATE pepper_csv2vec)
target_link_libraries(luci_partition PRIVATE oops)
install(TARGETS luci_partition DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
if(NOT ENABLE_TEST)
return()
*/
struct PartitionTable
{
+ enum class COMPLY
+ {
+ UNDEFINED,
+ OPCODE,
+ OPNAME,
+ };
+
std::vector<std::string> groups;
std::string default_group;
+ COMPLY comply = COMPLY::UNDEFINED;
// assign by opcode name: OPCODENAME=group
std::unordered_map<std::string /* OPCODENAME */, std::string /* group */> byopcodes;
- // TODO add assign by OP name
+ // assign by op name: OPNAME=group
+ std::unordered_map<std::string /* OPNAME */, std::string /* group */> byopnames;
};
/**
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITION_DUMP_H__
+#define __LUCI_PARTITION_DUMP_H__
+
+#include "luci/Partition.h"
+
+#include <iostream>
+
+std::ostream &operator<<(std::ostream &os, const luci::PartitionTable &table);
+
+#endif // __LUCI_PARTITION_DUMP_H__
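For orientation, the new comply field and byopnames map extend PartitionTable so a group can be assigned by operator name as well as by opcode name. A minimal sketch of how such a table might be consulted follows; the struct mirrors the diff, while group_of() and its lookup order (op name first under COMPLY::OPNAME, then opcode, then the default group) are assumptions, not the actual partition pass.

#include <string>
#include <unordered_map>
#include <vector>

struct PartitionTable
{
  enum class COMPLY { UNDEFINED, OPCODE, OPNAME };
  std::vector<std::string> groups;
  std::string default_group;
  COMPLY comply = COMPLY::UNDEFINED;
  std::unordered_map<std::string, std::string> byopcodes; // OPCODENAME=group
  std::unordered_map<std::string, std::string> byopnames; // OPNAME=group
};

// Hypothetical helper: resolve the group for one node.
std::string group_of(const PartitionTable &t, const std::string &opcode,
                     const std::string &opname)
{
  if (t.comply == PartitionTable::COMPLY::OPNAME)
  {
    auto it = t.byopnames.find(opname);
    if (it != t.byopnames.end())
      return it->second;
  }
  auto it = t.byopcodes.find(opcode);
  return it != t.byopcodes.end() ? it->second : t.default_group;
}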
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITION_VALIDATE_H__
+#define __LUCI_PARTITION_VALIDATE_H__
+
+#include "luci/Partition.h"
+
+#include <luci/IR/Module.h>
+
+namespace luci
+{
+
+bool validate(luci::PartitionTable &partition);
+
+} // namespace luci
+
+#endif // __LUCI_PARTITION_VALIDATE_H__
MapNode2Clone::iterator find(const CircleNode *org) { return node2clone.find(org); }
MapNode2Clone::iterator end(void) { return node2clone.end(); }
+ MapNode2Clone::const_iterator find(const CircleNode *org) const { return node2clone.find(org); }
+ MapNode2Clone::const_iterator end(void) const { return node2clone.end(); }
+
MapNode2Clone node2clone;
};
ConnectNode(luci::CloneContext &clonecontext) : _clonecontext(clonecontext){};
public:
- // void visit(const luci::CircleAbs *) final;
+ void visit(const luci::CircleAbs *) final;
void visit(const luci::CircleAdd *) final;
- // void visit(const luci::CircleAddN *) final;
- // void visit(const luci::CircleArgMax *) final;
- // void visit(const luci::CircleArgMin *) final;
- // void visit(const luci::CircleAveragePool2D *) final;
- // void visit(const luci::CircleBatchMatMul *) final;
- // void visit(const luci::CircleBatchToSpaceND *) final;
- // void visit(const luci::CircleCast *) final;
- // void visit(const luci::CircleCeil *) final;
- // void visit(const luci::CircleConcatenation *) final;
+ void visit(const luci::CircleAddN *) final;
+ void visit(const luci::CircleArgMax *) final;
+ void visit(const luci::CircleArgMin *) final;
+ void visit(const luci::CircleAveragePool2D *) final;
+ void visit(const luci::CircleBatchMatMul *) final;
+ void visit(const luci::CircleBatchToSpaceND *) final;
+ void visit(const luci::CircleCast *) final;
+ void visit(const luci::CircleCeil *) final;
+ void visit(const luci::CircleConcatenation *) final;
void visit(const luci::CircleConst *) final;
- // void visit(const luci::CircleConv2D *) final;
- // void visit(const luci::CircleCos *) final;
- // void visit(const luci::CircleCustom *) final;
- // void visit(const luci::CircleDepthToSpace *) final;
- // void visit(const luci::CircleDepthwiseConv2D *) final;
- // void visit(const luci::CircleDequantize *) final;
+ void visit(const luci::CircleConv2D *) final;
+ void visit(const luci::CircleCos *) final;
+ void visit(const luci::CircleCustom *) final;
+ void visit(const luci::CircleDepthToSpace *) final;
+ void visit(const luci::CircleDepthwiseConv2D *) final;
+ void visit(const luci::CircleDequantize *) final;
void visit(const luci::CircleDiv *) final;
- // void visit(const luci::CircleElu *) final;
- // void visit(const luci::CircleEqual *) final;
- // void visit(const luci::CircleExp *) final;
- // void visit(const luci::CircleExpandDims *) final;
- // void visit(const luci::CircleFakeQuant *) final;
- // void visit(const luci::CircleFill *) final;
- // void visit(const luci::CircleFloor *) final;
- // void visit(const luci::CircleFloorDiv *) final;
- // void visit(const luci::CircleFloorMod *) final;
- // void visit(const luci::CircleFullyConnected *) final;
- // void visit(const luci::CircleGather *) final;
- // void visit(const luci::CircleGatherNd *) final;
- // void visit(const luci::CircleGreater *) final;
- // void visit(const luci::CircleGreaterEqual *) final;
- // void visit(const luci::CircleIf *) final;
- // void visit(const luci::CircleL2Normalize *) final;
- // void visit(const luci::CircleL2Pool2D *) final;
- // void visit(const luci::CircleLeakyRelu *) final;
- // void visit(const luci::CircleLess *) final;
- // void visit(const luci::CircleLessEqual *) final;
- // void visit(const luci::CircleLocalResponseNormalization *) final;
- // void visit(const luci::CircleLog *) final;
- // void visit(const luci::CircleLogicalAnd *) final;
- // void visit(const luci::CircleLogicalNot *) final;
- // void visit(const luci::CircleLogicalOr *) final;
- // void visit(const luci::CircleLogistic *) final;
- // void visit(const luci::CircleLogSoftmax *) final;
- // void visit(const luci::CircleMatrixDiag *) final;
- // void visit(const luci::CircleMatrixSetDiag *) final;
- // void visit(const luci::CircleMaximum *) final;
- // void visit(const luci::CircleMaxPool2D *) final;
+ void visit(const luci::CircleElu *) final;
+ void visit(const luci::CircleEqual *) final;
+ void visit(const luci::CircleExp *) final;
+ void visit(const luci::CircleExpandDims *) final;
+ void visit(const luci::CircleFakeQuant *) final;
+ void visit(const luci::CircleFill *) final;
+ void visit(const luci::CircleFloor *) final;
+ void visit(const luci::CircleFloorDiv *) final;
+ void visit(const luci::CircleFloorMod *) final;
+ void visit(const luci::CircleFullyConnected *) final;
+ void visit(const luci::CircleGather *) final;
+ void visit(const luci::CircleGatherNd *) final;
+ void visit(const luci::CircleGreater *) final;
+ void visit(const luci::CircleGreaterEqual *) final;
+ void visit(const luci::CircleIf *) final;
+ void visit(const luci::CircleL2Normalize *) final;
+ void visit(const luci::CircleL2Pool2D *) final;
+ void visit(const luci::CircleLeakyRelu *) final;
+ void visit(const luci::CircleLess *) final;
+ void visit(const luci::CircleLessEqual *) final;
+ void visit(const luci::CircleLocalResponseNormalization *) final;
+ void visit(const luci::CircleLog *) final;
+ void visit(const luci::CircleLogicalAnd *) final;
+ void visit(const luci::CircleLogicalNot *) final;
+ void visit(const luci::CircleLogicalOr *) final;
+ void visit(const luci::CircleLogistic *) final;
+ void visit(const luci::CircleLogSoftmax *) final;
+ void visit(const luci::CircleMatrixDiag *) final;
+ void visit(const luci::CircleMatrixSetDiag *) final;
+ void visit(const luci::CircleMaximum *) final;
+ void visit(const luci::CircleMaxPool2D *) final;
void visit(const luci::CircleMean *) final;
- // void visit(const luci::CircleMinimum *) final;
- // void visit(const luci::CircleMirrorPad *) final;
+ void visit(const luci::CircleMinimum *) final;
+ void visit(const luci::CircleMirrorPad *) final;
void visit(const luci::CircleMul *) final;
- // void visit(const luci::CircleNeg *) final;
- // void visit(const luci::CircleNonMaxSuppressionV4 *) final;
- // void visit(const luci::CircleNonMaxSuppressionV5 *) final;
- // void visit(const luci::CircleNotEqual *) final;
- // void visit(const luci::CircleOneHot *) final;
- // void visit(const luci::CirclePack *) final;
- // void visit(const luci::CirclePad *) final;
- // void visit(const luci::CirclePadV2 *) final;
+ void visit(const luci::CircleNeg *) final;
+ void visit(const luci::CircleNonMaxSuppressionV4 *) final;
+ void visit(const luci::CircleNonMaxSuppressionV5 *) final;
+ void visit(const luci::CircleNotEqual *) final;
+ void visit(const luci::CircleOneHot *) final;
+ void visit(const luci::CirclePack *) final;
+ void visit(const luci::CirclePad *) final;
+ void visit(const luci::CirclePadV2 *) final;
void visit(const luci::CirclePow *) final;
- // void visit(const luci::CirclePRelu *) final;
- // void visit(const luci::CircleRange *) final;
- // void visit(const luci::CircleRank *) final;
- // void visit(const luci::CircleReduceAny *) final;
- // void visit(const luci::CircleReduceMax *) final;
- // void visit(const luci::CircleReduceMin *) final;
- // void visit(const luci::CircleReduceProd *) final;
- // void visit(const luci::CircleRelu *) final;
- // void visit(const luci::CircleRelu6 *) final;
- // void visit(const luci::CircleReluN1To1 *) final;
- // void visit(const luci::CircleReshape *) final;
- // void visit(const luci::CircleResizeBilinear *) final;
- // void visit(const luci::CircleResizeNearestNeighbor *) final;
- // void visit(const luci::CircleReverseSequence *) final;
- // void visit(const luci::CircleReverseV2 *) final;
- // void visit(const luci::CircleRound *) final;
+ void visit(const luci::CirclePRelu *) final;
+ void visit(const luci::CircleQuantize *) final;
+ void visit(const luci::CircleRange *) final;
+ void visit(const luci::CircleRank *) final;
+ void visit(const luci::CircleReduceAny *) final;
+ void visit(const luci::CircleReduceMax *) final;
+ void visit(const luci::CircleReduceMin *) final;
+ void visit(const luci::CircleReduceProd *) final;
+ void visit(const luci::CircleRelu *) final;
+ void visit(const luci::CircleRelu6 *) final;
+ void visit(const luci::CircleReluN1To1 *) final;
+ void visit(const luci::CircleReshape *) final;
+ void visit(const luci::CircleResizeBilinear *) final;
+ void visit(const luci::CircleResizeNearestNeighbor *) final;
+ void visit(const luci::CircleReverseSequence *) final;
+ void visit(const luci::CircleReverseV2 *) final;
+ void visit(const luci::CircleRound *) final;
void visit(const luci::CircleRsqrt *) final;
- // void visit(const luci::CircleScatterNd *) final;
- // void visit(const luci::CircleSegmentSum *) final;
- // void visit(const luci::CircleSelect *) final;
- // void visit(const luci::CircleSelectV2 *) final;
- // void visit(const luci::CircleShape *) final;
- // void visit(const luci::CircleSin *) final;
- // void visit(const luci::CircleSlice *) final;
- // void visit(const luci::CircleSoftmax *) final;
- // void visit(const luci::CircleSpaceToBatchND *) final;
- // void visit(const luci::CircleSpaceToDepth *) final;
- // void visit(const luci::CircleSparseToDense *) final;
- // void visit(const luci::CircleSplit *) final;
- // void visit(const luci::CircleSplitV *) final;
+ void visit(const luci::CircleScatterNd *) final;
+ void visit(const luci::CircleSegmentSum *) final;
+ void visit(const luci::CircleSelect *) final;
+ void visit(const luci::CircleSelectV2 *) final;
+ void visit(const luci::CircleShape *) final;
+ void visit(const luci::CircleSin *) final;
+ void visit(const luci::CircleSlice *) final;
+ void visit(const luci::CircleSoftmax *) final;
+ void visit(const luci::CircleSpaceToBatchND *) final;
+ void visit(const luci::CircleSpaceToDepth *) final;
+ void visit(const luci::CircleSparseToDense *) final;
+ void visit(const luci::CircleSplit *) final;
+ void visit(const luci::CircleSplitV *) final;
void visit(const luci::CircleSqrt *) final;
- // void visit(const luci::CircleSquare *) final;
+ void visit(const luci::CircleSquare *) final;
void visit(const luci::CircleSquaredDifference *) final;
- // void visit(const luci::CircleSqueeze *) final;
- // void visit(const luci::CircleStridedSlice *) final;
+ void visit(const luci::CircleSqueeze *) final;
+ void visit(const luci::CircleStridedSlice *) final;
void visit(const luci::CircleSub *) final;
- // void visit(const luci::CircleSum *) final;
- // void visit(const luci::CircleTanh *) final;
- // void visit(const luci::CircleTile *) final;
- // void visit(const luci::CircleTopKV2 *) final;
- // void visit(const luci::CircleTranspose *) final;
- // void visit(const luci::CircleTransposeConv *) final;
- // void visit(const luci::CircleUnidirectionalSequenceLSTM *) final;
- // void visit(const luci::CircleUnique *) final;
- // void visit(const luci::CircleUnpack *) final;
- // void visit(const luci::CircleWhere *) final;
- // void visit(const luci::CircleWhile *) final;
- // void visit(const luci::CircleZerosLike *) final;
+ void visit(const luci::CircleSum *) final;
+ void visit(const luci::CircleTanh *) final;
+ void visit(const luci::CircleTile *) final;
+ void visit(const luci::CircleTopKV2 *) final;
+ void visit(const luci::CircleTranspose *) final;
+ void visit(const luci::CircleTransposeConv *) final;
+ void visit(const luci::CircleUnidirectionalSequenceLSTM *) final;
+ void visit(const luci::CircleUnique *) final;
+ void visit(const luci::CircleUnpack *) final;
+ void visit(const luci::CircleWhere *) final;
+ void visit(const luci::CircleWhile *) final;
+ void visit(const luci::CircleZerosLike *) final;
// Circle Only
- // void visit(const luci::CircleBCQFullyConnected *) final;
- // void visit(const luci::CircleBCQGather *) final;
- // void visit(const luci::CircleInstanceNorm *) final;
+ void visit(const luci::CircleBCQFullyConnected *) final;
+ void visit(const luci::CircleBCQGather *) final;
+ void visit(const luci::CircleInstanceNorm *) final;
+
+ // NOTE CircleInput and CircleOutput are not handled here as they need to
+ // be linked with the graph I/O
// Virtual
- // void visit(const luci::CircleCustomOut *) final;
- // void visit(const luci::CircleIfOut *) final;
+ void visit(const luci::CircleCustomOut *) final;
+ void visit(const luci::CircleIfOut *) final;
// void visit(const luci::CircleInput *) final;
- // void visit(const luci::CircleNonMaxSuppressionV4Out *) final;
- // void visit(const luci::CircleNonMaxSuppressionV5Out *) final;
+ void visit(const luci::CircleNonMaxSuppressionV4Out *) final;
+ void visit(const luci::CircleNonMaxSuppressionV5Out *) final;
// void visit(const luci::CircleOutput *) final;
- // void visit(const luci::CircleOutputDummy *) final;
- // void visit(const luci::CircleOutputExclude *) final;
- // void visit(const luci::CircleSplitOut *) final;
- // void visit(const luci::CircleSplitVOut *) final;
- // void visit(const luci::CircleTopKV2Out *) final;
- // void visit(const luci::CircleUniqueOut *) final;
- // void visit(const luci::CircleUnpackOut *) final;
- // void visit(const luci::CircleWhileOut *) final;
+ void visit(const luci::CircleOutputDummy *) final;
+ void visit(const luci::CircleOutputExclude *) final;
+ void visit(const luci::CircleSplitOut *) final;
+ void visit(const luci::CircleSplitVOut *) final;
+ void visit(const luci::CircleTopKV2Out *) final;
+ void visit(const luci::CircleUniqueOut *) final;
+ void visit(const luci::CircleUnpackOut *) final;
+ void visit(const luci::CircleWhileOut *) final;
public:
luci::CircleNode *find_clone(const luci::CircleNode *node);
if (shape_in.size() != N)
throw std::runtime_error("Failed to init TestIsOGraph");
- TestIsGraphlet<N>::init(TestIsGraphlet<N>::g(), shape_in);
- TestOGraphlet::init(TestIsGraphlet<N>::g(), shape_out);
+ auto g = TestIsGraphlet<N>::g();
+ TestIsGraphlet<N>::init(g, shape_in);
+ TestOGraphlet::init(g, shape_out);
}
};
T *_node{nullptr};
};
+template <class T> class NodeIsOsGraphletT
+{
+public:
+ virtual void init(loco::Graph *g, uint32_t n, uint32_t m)
+ {
+ _node = g->nodes()->create<T>(n, m);
+ _node->dtype(loco::DataType::S32);
+ _node->name("node");
+ }
+
+ T *node(void) const { return _node; }
+
+protected:
+ T *_node{nullptr};
+};
+
+template <unsigned N, unsigned M>
+class TestIsOsGraph : public TestIsGraphlet<N>, public TestOsGraphlet<M>
+{
+public:
+ TestIsOsGraph() = default;
+
+public:
+ virtual void init(const std::initializer_list<ShapeU32> shape_in,
+ const std::initializer_list<ShapeU32> shape_out)
+ {
+ if (shape_in.size() != N)
+ throw std::runtime_error("Failed to init TestIsOsGraph");
+ if (shape_out.size() != M)
+ throw std::runtime_error("Failed to init TestIsOsGraph");
+
+ auto g = TestIsGraphlet<N>::g();
+ TestIsGraphlet<N>::init(g, shape_in);
+ TestOsGraphlet<M>::init(g, shape_out);
+ }
+};
+
/**
* @brief ConnectionTestHelper provides common framework for testing
* cloned CircleNode connection
}
}
+ template <unsigned N, unsigned M> void prepare_inputs(TestIsOsGraph<N, M> *isosgraph)
+ {
+ assert(N == isosgraph->num_inputs());
+ assert(M == isosgraph->num_outputs());
+
+ for (uint32_t i = 0; i < N; ++i)
+ {
+ auto *input = _graph_clone->nodes()->create<luci::CircleInput>();
+ luci::copy_common_attributes(isosgraph->input(i), input);
+ _clonectx.emplace(isosgraph->input(i), input);
+ _inputs.push_back(input);
+ }
+ }
+
+ /**
+ * @note Although there is only one input, the method name keeps the 's' to keep the test code simple
+ */
+ void prepare_inputs(TestIOGraph *isograph)
+ {
+ assert(1 == isograph->num_inputs());
+
+ auto *input = _graph_clone->nodes()->create<luci::CircleInput>();
+ luci::copy_common_attributes(isograph->input(), input);
+ _clonectx.emplace(isograph->input(), input);
+ _inputs.push_back(input);
+ }
+
/**
* @note prepare_inputs_miss is for negative testing
*/
}
}
+ template <unsigned N, unsigned M> void prepare_inputs_miss(TestIsOsGraph<N, M> *isograph)
+ {
+ assert(N == isograph->num_inputs());
+ assert(M == isograph->num_outputs());
+
+ for (uint32_t i = 0; i < N; ++i)
+ {
+ auto *input = _graph_clone->nodes()->create<luci::CircleInput>();
+ luci::copy_common_attributes(isograph->input(i), input);
+ if (i != 0)
+ _clonectx.emplace(isograph->input(i), input);
+ _inputs.push_back(input);
+ }
+ }
+
+ void prepare_inputs_miss(TestIOGraph *isograph)
+ {
+ assert(1 == isograph->num_inputs());
+
+ auto *input = _graph_clone->nodes()->create<luci::CircleInput>();
+ luci::copy_common_attributes(isograph->input(), input);
+ // _clonectx.emplace() is NOT called on purpose
+ _inputs.push_back(input);
+ }
+
void clone_connect(luci::CircleNode *node, luci::CircleNode *clone)
{
_clonectx.emplace(node, clone);
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+/**
+ * @note This function and the other connect() helpers exist only to reduce the LOC of the ConnectNode class
+ */
+void connect(luci::ConnectNode *cn, const luci::CircleAbs *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleAbs *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleAbs *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleAbs>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Abs)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAbs *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAbs *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Abs_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAbs *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAbs *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleAddN *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleAddN *>(cn->find_clone(node));
+
+ uint32_t num_inputs = cloned->arity();
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->inputs(i));
+
+ cloned->inputs(i, cn->find_clone(input));
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleAddN *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeIsGraphletT<luci::CircleAddN>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g(), 3);
+
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ node()->inputs(i, input(i));
+ }
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_AddN)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAddN *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAddN *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_AddN_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAddN *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAddN *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleArgMax *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleArgMax *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *dimension = loco::must_cast<luci::CircleNode *>(node->dimension());
+
+ cloned->input(cn->find_clone(input));
+ cloned->dimension(cn->find_clone(dimension));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleArgMax *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleArgMax>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->dimension(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ArgMax)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMax *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMax *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ArgMax_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMax *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMax *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleArgMin *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleArgMin *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *dimension = loco::must_cast<luci::CircleNode *>(node->dimension());
+
+ cloned->input(cn->find_clone(input));
+ cloned->dimension(cn->find_clone(dimension));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleArgMin *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleArgMin>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->dimension(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ArgMin)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMin *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMin *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ArgMin_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMin *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMin *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleAveragePool2D *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleAveragePool2D *>(cn->find_clone(node));
+
+ luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value());
+
+ cloned->value(cn->find_clone(value));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleAveragePool2D *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleAveragePool2D>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleAveragePool2D>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ _node->padding(luci::Padding::VALID);
+ }
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->value(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_AveragePool2D)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAveragePool2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAveragePool2D *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_AveragePool2D_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAveragePool2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAveragePool2D *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleBCQFullyConnected *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleBCQFullyConnected *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *weights_scales = loco::must_cast<luci::CircleNode *>(node->weights_scales());
+ luci::CircleNode *weights_binary = loco::must_cast<luci::CircleNode *>(node->weights_binary());
+ luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias());
+ luci::CircleNode *weights_clusters =
+ loco::must_cast<luci::CircleNode *>(node->weights_clusters());
+
+ cloned->input(cn->find_clone(input));
+ cloned->weights_scales(cn->find_clone(weights_scales));
+ cloned->weights_binary(cn->find_clone(weights_binary));
+ cloned->bias(cn->find_clone(bias));
+ cloned->weights_clusters(cn->find_clone(weights_clusters));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleBCQFullyConnected *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleBCQFullyConnected>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleBCQFullyConnected>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<5>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<5>::init({shape, shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->weights_scales(input(1));
+ node()->weights_binary(input(2));
+ node()->bias(input(3));
+ node()->weights_clusters(input(4));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_BCQFullyConnected)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQFullyConnected *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQFullyConnected *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(5, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+ ASSERT_EQ(cth.inputs(4), clone->arg(4));
+}
+
+TEST(ConnectNodeTest, connect_BCQFullyConnected_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQFullyConnected *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQFullyConnected *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleBCQGather *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleBCQGather *>(cn->find_clone(node));
+
+ luci::CircleNode *input_scales = loco::must_cast<luci::CircleNode *>(node->input_scales());
+ luci::CircleNode *input_binary = loco::must_cast<luci::CircleNode *>(node->input_binary());
+ luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices());
+ luci::CircleNode *input_clusters = loco::must_cast<luci::CircleNode *>(node->input_clusters());
+
+ cloned->input_scales(cn->find_clone(input_scales));
+ cloned->input_binary(cn->find_clone(input_binary));
+ cloned->indices(cn->find_clone(indices));
+ cloned->input_clusters(cn->find_clone(input_clusters));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleBCQGather *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleBCQGather>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<4>::init({shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input_scales(input(0));
+ node()->input_binary(input(1));
+ node()->indices(input(2));
+ node()->input_clusters(input(3));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_BCQGather)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQGather *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQGather *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(4, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+}
+
+TEST(ConnectNodeTest, connect_BCQGather_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQGather *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQGather *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleBatchMatMul *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleBatchMatMul *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleBatchMatMul *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleBatchMatMul>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_BatchMatMul)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchMatMul *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchMatMul *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_BatchMatMul_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchMatMul *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchMatMul *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleBatchToSpaceND *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleBatchToSpaceND *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *block_shape = loco::must_cast<luci::CircleNode *>(node->block_shape());
+ luci::CircleNode *crops = loco::must_cast<luci::CircleNode *>(node->crops());
+
+ cloned->input(cn->find_clone(input));
+ cloned->block_shape(cn->find_clone(block_shape));
+ cloned->crops(cn->find_clone(crops));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleBatchToSpaceND *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleBatchToSpaceND>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->block_shape(input(1));
+ node()->crops(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_BatchToSpaceND)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchToSpaceND *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchToSpaceND *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_BatchToSpaceND_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchToSpaceND *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchToSpaceND *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleCast *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleCast *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleCast *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleCast>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Cast)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCast *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCast *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Cast_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCast *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCast *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleCeil *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleCeil *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleCeil *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleCeil>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Ceil)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCeil *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+  ASSERT_NO_THROW(loco::must_cast<luci::CircleCeil *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Ceil_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCeil *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+  ASSERT_NO_THROW(loco::must_cast<luci::CircleCeil *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleConcatenation *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleConcatenation *>(cn->find_clone(node));
+
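+  // Concatenation has a variadic number of inputs; remap each value to its clone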
+ uint32_t num_inputs = cloned->numValues();
+ for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->values(i));
+
+ cloned->values(i, cn->find_clone(value));
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleConcatenation *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeIsGraphletT<luci::CircleConcatenation>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, uint32_t n) override
+ {
+ NodeIsGraphletT<luci::CircleConcatenation>::init(g, n);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g(), 3);
+
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ node()->values(i, input(i));
+ }
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Concatenation)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleConcatenation *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleConcatenation *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_Concatenation_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleConcatenation *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleConcatenation *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleConv2D *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleConv2D *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *filter = loco::must_cast<luci::CircleNode *>(node->filter());
+ luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias());
+
+ cloned->input(cn->find_clone(input));
+ cloned->filter(cn->find_clone(filter));
+ cloned->bias(cn->find_clone(bias));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleConv2D *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleConv2D>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleConv2D>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ _node->padding(luci::Padding::VALID);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->filter(input(1));
+ node()->bias(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Conv2D)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleConv2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleConv2D *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_Conv2D_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleConv2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleConv2D *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleCos *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleCos *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleCos *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleCos>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Cos)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCos *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCos *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Cos_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCos *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCos *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleCustom *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleCustom *>(cn->find_clone(node));
+
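+  // Custom has a variadic number of inputs; remap each input to its clone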
+  uint32_t num_inputs = cloned->numInputs();
+  for (uint32_t i = 0; i < num_inputs; ++i)
+ {
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->inputs(i));
+
+ cloned->inputs(i, cn->find_clone(input));
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleCustom *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+/**
+ * @note Does not use the template helpers like the others, as only Custom has both multiple inputs and outputs
+ */
+class NodeGraphlet
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ virtual void init(loco::Graph *g, uint32_t in, uint32_t out)
+ {
+ _node = g->nodes()->create<luci::CircleCustom>(in, out);
+ _node->dtype(loco::DataType::S32);
+ _node->name("node");
+ }
+
+ luci::CircleCustom *node(void) const { return _node; }
+
+protected:
+ luci::CircleCustom *_node = nullptr;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g(), 3, 3);
+
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ node()->inputs(i, input(i));
+ }
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Custom)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCustom *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCustom *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_Custom_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCustom *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCustom *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleCustomOut *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleCustomOut *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleCustomOut *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleCustomOut>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_CustomOut)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCustomOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCustomOut *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_CustomOut_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCustomOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleCustomOut *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleDepthToSpace *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleDepthToSpace *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleDepthToSpace *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleDepthToSpace>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_DepthToSpace)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthToSpace *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthToSpace *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_DepthToSpace_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthToSpace *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthToSpace *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleDepthwiseConv2D *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleDepthwiseConv2D *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *filter = loco::must_cast<luci::CircleNode *>(node->filter());
+ luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias());
+
+ cloned->input(cn->find_clone(input));
+ cloned->filter(cn->find_clone(filter));
+ cloned->bias(cn->find_clone(bias));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleDepthwiseConv2D *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleDepthwiseConv2D>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleDepthwiseConv2D>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ _node->padding(luci::Padding::VALID);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->filter(input(1));
+ node()->bias(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_DepthwiseConv2D)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthwiseConv2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthwiseConv2D *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_DepthwiseConv2D_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthwiseConv2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthwiseConv2D *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleDequantize *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleDequantize *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleDequantize *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleDequantize>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Dequantize)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDequantize *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDequantize *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Dequantize_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDequantize *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDequantize *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleElu *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleElu *>(cn->find_clone(node));
+
+ luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features());
+
+ cloned->features(cn->find_clone(features));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleElu *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleElu>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->features(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Elu)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleElu *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleElu *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Elu_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleElu *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleElu *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleEqual *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleEqual *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleEqual *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleEqual>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Equal)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleEqual *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleEqual *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Equal_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleEqual *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleEqual *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleExp *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleExp *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleExp *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleExp>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Exp)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleExp *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleExp *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Exp_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleExp *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleExp *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleExpandDims *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleExpandDims *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *axis = loco::must_cast<luci::CircleNode *>(node->axis());
+
+ cloned->input(cn->find_clone(input));
+ cloned->axis(cn->find_clone(axis));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleExpandDims *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleExpandDims>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->axis(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ExpandDims)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleExpandDims *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleExpandDims *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ExpandDims_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleExpandDims *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleExpandDims *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleFakeQuant *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleFakeQuant *>(cn->find_clone(node));
+
+ luci::CircleNode *inputs = loco::must_cast<luci::CircleNode *>(node->inputs());
+
+ cloned->inputs(cn->find_clone(inputs));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleFakeQuant *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleFakeQuant>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->inputs(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_FakeQuant)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFakeQuant *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFakeQuant *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_FakeQuant_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFakeQuant *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFakeQuant *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleFill *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleFill *>(cn->find_clone(node));
+
+ luci::CircleNode *dims = loco::must_cast<luci::CircleNode *>(node->dims());
+ luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value());
+
+ cloned->dims(cn->find_clone(dims));
+ cloned->value(cn->find_clone(value));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleFill *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleFill>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->dims(input(0));
+ node()->value(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Fill)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFill *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFill *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Fill_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFill *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFill *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleFloor *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleFloor *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleFloor *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleFloor>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Floor)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloor *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloor *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Floor_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloor *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloor *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleFloorDiv *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleFloorDiv *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleFloorDiv *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleFloorDiv>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_FloorDiv)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorDiv *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorDiv *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_FloorDiv_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorDiv *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorDiv *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleFloorMod *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleFloorMod *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleFloorMod *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleFloorMod>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_FloorMod)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorMod *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorMod *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_FloorMod_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorMod *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorMod *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleFullyConnected *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleFullyConnected *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *weights = loco::must_cast<luci::CircleNode *>(node->weights());
+ luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias());
+
+ cloned->input(cn->find_clone(input));
+ cloned->weights(cn->find_clone(weights));
+ cloned->bias(cn->find_clone(bias));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleFullyConnected *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleFullyConnected>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleFullyConnected>::init(g);
+
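+ // give the node concrete attribute values (RELU activation, DEFAULT weights format) before cloning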
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ _node->weights_format(luci::CircleFullyConnected::WeightsFormat::DEFAULT);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->weights(input(1));
+ node()->bias(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_FullyConnected)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFullyConnected *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFullyConnected *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_FullyConnected_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFullyConnected *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleFullyConnected *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleGather *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleGather *>(cn->find_clone(node));
+
+ luci::CircleNode *params = loco::must_cast<luci::CircleNode *>(node->params());
+ luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices());
+
+ cloned->params(cn->find_clone(params));
+ cloned->indices(cn->find_clone(indices));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleGather *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleGather>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->params(input(0));
+ node()->indices(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Gather)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGather *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGather *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Gather_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGather *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGather *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleGatherNd *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleGatherNd *>(cn->find_clone(node));
+
+ luci::CircleNode *params = loco::must_cast<luci::CircleNode *>(node->params());
+ luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices());
+
+ cloned->params(cn->find_clone(params));
+ cloned->indices(cn->find_clone(indices));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleGatherNd *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleGatherNd>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->params(input(0));
+ node()->indices(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_GatherNd)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGatherNd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGatherNd *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_GatherNd_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGatherNd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGatherNd *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleGreater *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleGreater *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleGreater *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleGreater>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Greater)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGreater *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGreater *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Greater_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGreater *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGreater *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleGreaterEqual *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleGreaterEqual *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleGreaterEqual *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleGreaterEqual>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_GreaterEqual)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGreaterEqual *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGreaterEqual *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_GreaterEqual_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGreaterEqual *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleGreaterEqual *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleIf *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleIf *>(cn->find_clone(node));
+
+ luci::CircleNode *cond = loco::must_cast<luci::CircleNode *>(node->cond());
+
+ cloned->cond(cn->find_clone(cond));
+
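+ // reconnect each variadic input of the If node to its clone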
+ auto input_count = node->input_count();
+ for (uint32_t in = 0; in < input_count; ++in)
+ {
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input(in));
+
+ cloned->input(in, cn->find_clone(input));
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleIf *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeIsOsGraphletT<luci::CircleIf>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, uint32_t n, uint32_t m) override
+ {
+ // cond() consumes one of the n inputs, so the node itself is initialized with n - 1
+ NodeIsOsGraphletT::init(g, n - 1, m);
+ }
+};
+
+class TestNodeGraph : public TestIsOsGraph<3, 1>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOsGraph<3, 1>::init({shape, shape, shape}, {shape});
+ NodeGraphlet::init(g(), 3, 1);
+
+ node()->cond(input(0));
+ node()->input(0, input(1));
+ node()->input(1, input(2));
+
+ output(0)->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_If)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs<3, 1>(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleIf *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleIf *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ // arity(3) = cond + input(2)
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_If_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss<3, 1>(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleIf *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleIf *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleIfOut *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleIfOut *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleIfOut *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleIfOut>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_IfOut)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleIfOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleIfOut *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_IfOut_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleIfOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleIfOut *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleInstanceNorm *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleInstanceNorm *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *gamma = loco::must_cast<luci::CircleNode *>(node->gamma());
+ luci::CircleNode *beta = loco::must_cast<luci::CircleNode *>(node->beta());
+
+ cloned->input(cn->find_clone(input));
+ cloned->gamma(cn->find_clone(gamma));
+ cloned->beta(cn->find_clone(beta));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleInstanceNorm *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleInstanceNorm>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleInstanceNorm>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->gamma(input(1));
+ node()->beta(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_InstanceNorm)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleInstanceNorm *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleInstanceNorm *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_InstanceNorm_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleInstanceNorm *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleInstanceNorm *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleL2Normalize *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleL2Normalize *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleL2Normalize *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleL2Normalize>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleL2Normalize>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_L2Normalize)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Normalize *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Normalize *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_L2Normalize_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Normalize *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Normalize *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleL2Pool2D *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleL2Pool2D *>(cn->find_clone(node));
+
+ luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value());
+
+ cloned->value(cn->find_clone(value));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleL2Pool2D *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleL2Pool2D>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleL2Pool2D>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ _node->padding(luci::Padding::VALID);
+ }
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->value(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_L2Pool2D)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Pool2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Pool2D *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_L2Pool2D_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Pool2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Pool2D *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLeakyRelu *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLeakyRelu *>(cn->find_clone(node));
+
+ luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features());
+
+ cloned->features(cn->find_clone(features));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLeakyRelu *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLeakyRelu>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->features(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_LeakyRelu)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLeakyRelu *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLeakyRelu *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_LeakyRelu_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLeakyRelu *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLeakyRelu *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLess *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLess *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLess *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLess>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Less)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLess *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLess *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Less_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLess *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLess *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLessEqual *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLessEqual *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLessEqual *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLessEqual>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_LessEqual)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLessEqual *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLessEqual *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_LessEqual_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLessEqual *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLessEqual *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLocalResponseNormalization *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLocalResponseNormalization *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLocalResponseNormalization *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLocalResponseNormalization>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_LocalResponseNormalization)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLocalResponseNormalization *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLocalResponseNormalization *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_LocalResponseNormalization_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLocalResponseNormalization *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLocalResponseNormalization *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLog *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLog *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLog *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLog>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Log)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLog *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLog *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Log_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLog *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLog *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLogSoftmax *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLogSoftmax *>(cn->find_clone(node));
+
+ luci::CircleNode *logits = loco::must_cast<luci::CircleNode *>(node->logits());
+
+ cloned->logits(cn->find_clone(logits));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLogSoftmax *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLogSoftmax>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->logits(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_LogSoftmax)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogSoftmax *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogSoftmax *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_LogSoftmax_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogSoftmax *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogSoftmax *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLogicalAnd *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLogicalAnd *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLogicalAnd *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLogicalAnd>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_LogicalAnd)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalAnd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalAnd *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_LogicalAnd_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalAnd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalAnd *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLogicalNot *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLogicalNot *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLogicalNot *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLogicalNot>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_LogicalNot)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalNot *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalNot *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_LogicalNot_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalNot *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalNot *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLogicalOr *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLogicalOr *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLogicalOr *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLogicalOr>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_LogicalOr)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalOr *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalOr *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_LogicalOr_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalOr *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalOr *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleLogistic *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleLogistic *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleLogistic *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleLogistic>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Logistic)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogistic *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogistic *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Logistic_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogistic *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleLogistic *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleMatrixDiag *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleMatrixDiag *>(cn->find_clone(node));
+
+ luci::CircleNode *diagonal = loco::must_cast<luci::CircleNode *>(node->diagonal());
+
+ cloned->diagonal(cn->find_clone(diagonal));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleMatrixDiag *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleMatrixDiag>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->diagonal(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_MatrixDiag)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixDiag *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixDiag *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_MatrixDiag_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixDiag *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixDiag *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleMatrixSetDiag *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleMatrixSetDiag *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *diagonal = loco::must_cast<luci::CircleNode *>(node->diagonal());
+
+ cloned->input(cn->find_clone(input));
+ cloned->diagonal(cn->find_clone(diagonal));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleMatrixSetDiag *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleMatrixSetDiag>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->diagonal(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_MatrixSetDiag)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixSetDiag *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixSetDiag *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_MatrixSetDiag_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixSetDiag *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixSetDiag *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleMaxPool2D *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleMaxPool2D *>(cn->find_clone(node));
+
+ luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value());
+
+ cloned->value(cn->find_clone(value));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleMaxPool2D *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleMaxPool2D>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleMaxPool2D>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ _node->padding(luci::Padding::VALID);
+ }
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->value(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_MaxPool2D)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMaxPool2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMaxPool2D *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_MaxPool2D_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMaxPool2D *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMaxPool2D *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleMaximum *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleMaximum *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleMaximum *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleMaximum>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Maximum)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMaximum *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMaximum *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Maximum_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMaximum *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMaximum *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleMean>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->reduction_indices(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Mean)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMean *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMean *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Mean_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMean *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMean *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleMinimum *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleMinimum *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleMinimum *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleMinimum>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Minimum)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMinimum *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMinimum *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Minimum_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMinimum *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMinimum *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleMirrorPad *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleMirrorPad *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *paddings = loco::must_cast<luci::CircleNode *>(node->paddings());
+
+ cloned->input(cn->find_clone(input));
+ cloned->paddings(cn->find_clone(paddings));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleMirrorPad *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleMirrorPad>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleMirrorPad>::init(g);
+
+ _node->mode(luci::MirrorPadMode::REFLECT);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->paddings(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_MirrorPad)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMirrorPad *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMirrorPad *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_MirrorPad_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMirrorPad *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMirrorPad *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleNeg *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleNeg *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleNeg *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleNeg>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Neg)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNeg *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNeg *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Neg_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNeg *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNeg *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleNonMaxSuppressionV4 *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(cn->find_clone(node));
+
+ luci::CircleNode *boxes = loco::must_cast<luci::CircleNode *>(node->boxes());
+ luci::CircleNode *scores = loco::must_cast<luci::CircleNode *>(node->scores());
+ luci::CircleNode *max_output_size = loco::must_cast<luci::CircleNode *>(node->max_output_size());
+ luci::CircleNode *iou_threshold = loco::must_cast<luci::CircleNode *>(node->iou_threshold());
+ luci::CircleNode *score_threshold = loco::must_cast<luci::CircleNode *>(node->score_threshold());
+
+ cloned->boxes(cn->find_clone(boxes));
+ cloned->scores(cn->find_clone(scores));
+ cloned->max_output_size(cn->find_clone(max_output_size));
+ cloned->iou_threshold(cn->find_clone(iou_threshold));
+ cloned->score_threshold(cn->find_clone(score_threshold));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleNonMaxSuppressionV4 *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleNonMaxSuppressionV4>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<5>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<5>::init({shape, shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->boxes(input(0));
+ node()->scores(input(1));
+ node()->max_output_size(input(2));
+ node()->iou_threshold(input(3));
+ node()->score_threshold(input(4));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_NonMaxSuppressionV4)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(5, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+ ASSERT_EQ(cth.inputs(4), clone->arg(4));
+}
+
+TEST(ConnectNodeTest, connect_NonMaxSuppressionV4_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleNonMaxSuppressionV4Out *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleNonMaxSuppressionV4Out *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleNonMaxSuppressionV4Out>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_NonMaxSuppressionV4Out)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_NonMaxSuppressionV4Out_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleNonMaxSuppressionV5 *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(cn->find_clone(node));
+
+ luci::CircleNode *boxes = loco::must_cast<luci::CircleNode *>(node->boxes());
+ luci::CircleNode *scores = loco::must_cast<luci::CircleNode *>(node->scores());
+ luci::CircleNode *max_output_size = loco::must_cast<luci::CircleNode *>(node->max_output_size());
+ luci::CircleNode *iou_threshold = loco::must_cast<luci::CircleNode *>(node->iou_threshold());
+ luci::CircleNode *score_threshold = loco::must_cast<luci::CircleNode *>(node->score_threshold());
+ luci::CircleNode *soft_nms_sigma = loco::must_cast<luci::CircleNode *>(node->soft_nms_sigma());
+
+ cloned->boxes(cn->find_clone(boxes));
+ cloned->scores(cn->find_clone(scores));
+ cloned->max_output_size(cn->find_clone(max_output_size));
+ cloned->iou_threshold(cn->find_clone(iou_threshold));
+ cloned->score_threshold(cn->find_clone(score_threshold));
+ cloned->soft_nms_sigma(cn->find_clone(soft_nms_sigma));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleNonMaxSuppressionV5 *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleNonMaxSuppressionV5>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<6>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<6>::init({shape, shape, shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->boxes(input(0));
+ node()->scores(input(1));
+ node()->max_output_size(input(2));
+ node()->iou_threshold(input(3));
+ node()->score_threshold(input(4));
+ node()->soft_nms_sigma(input(5));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_NonMaxSuppressionV5)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(6, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+ ASSERT_EQ(cth.inputs(4), clone->arg(4));
+ ASSERT_EQ(cth.inputs(5), clone->arg(5));
+}
+
+TEST(ConnectNodeTest, connect_NonMaxSuppressionV5_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleNonMaxSuppressionV5Out *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleNonMaxSuppressionV5Out *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleNonMaxSuppressionV5Out>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_NonMaxSuppressionV5Out)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_NonMaxSuppressionV5Out_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleNotEqual *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleNotEqual *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleNotEqual *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleNotEqual>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_NotEqual)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNotEqual *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNotEqual *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_NotEqual_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNotEqual *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleNotEqual *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleOneHot *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleOneHot *>(cn->find_clone(node));
+
+ luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices());
+ luci::CircleNode *depth = loco::must_cast<luci::CircleNode *>(node->depth());
+ luci::CircleNode *on_value = loco::must_cast<luci::CircleNode *>(node->on_value());
+ luci::CircleNode *off_value = loco::must_cast<luci::CircleNode *>(node->off_value());
+
+ cloned->indices(cn->find_clone(indices));
+ cloned->depth(cn->find_clone(depth));
+ cloned->on_value(cn->find_clone(on_value));
+ cloned->off_value(cn->find_clone(off_value));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleOneHot *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleOneHot>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<4>::init({shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->indices(input(0));
+ node()->depth(input(1));
+ node()->on_value(input(2));
+ node()->off_value(input(3));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_OneHot)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleOneHot *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleOneHot *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(4, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+}
+
+TEST(ConnectNodeTest, connect_OneHot_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleOneHot *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleOneHot *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleOutputDummy *)
+{
+ // Nothing to do
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleOutputExclude *)
+{
+ // Nothing to do
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CirclePRelu *node)
+{
+ auto *cloned = loco::must_cast<luci::CirclePRelu *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *alpha = loco::must_cast<luci::CircleNode *>(node->alpha());
+
+ cloned->input(cn->find_clone(input));
+ cloned->alpha(cn->find_clone(alpha));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CirclePRelu *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CirclePRelu>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->alpha(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_PRelu)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePRelu *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePRelu *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_PRelu_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePRelu *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePRelu *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CirclePack *node)
+{
+ auto *cloned = loco::must_cast<luci::CirclePack *>(cn->find_clone(node));
+
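+  // Pack has a variadic input list; remap every value operand to its clone.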
+ uint32_t values_count = cloned->values_count();
+ for (uint32_t i = 0; i < values_count; ++i)
+ {
+ luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->values(i));
+
+ cloned->values(i, cn->find_clone(value));
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CirclePack *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeIsGraphletT<luci::CirclePack>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g(), 3);
+
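+    // Wire each of the three graph inputs as one packed value.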
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ node()->values(i, input(i));
+ }
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Pack)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePack *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePack *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_Pack_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePack *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePack *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CirclePad *node)
+{
+ auto *cloned = loco::must_cast<luci::CirclePad *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *paddings = loco::must_cast<luci::CircleNode *>(node->paddings());
+
+ cloned->input(cn->find_clone(input));
+ cloned->paddings(cn->find_clone(paddings));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CirclePad *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CirclePad>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->paddings(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Pad)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePad *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePad *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Pad_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePad *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePad *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CirclePadV2 *node)
+{
+ auto *cloned = loco::must_cast<luci::CirclePadV2 *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *paddings = loco::must_cast<luci::CircleNode *>(node->paddings());
+ luci::CircleNode *constant_values = loco::must_cast<luci::CircleNode *>(node->constant_values());
+
+ cloned->input(cn->find_clone(input));
+ cloned->paddings(cn->find_clone(paddings));
+ cloned->constant_values(cn->find_clone(constant_values));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CirclePadV2 *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CirclePadV2>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->paddings(input(1));
+ node()->constant_values(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_PadV2)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePadV2 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePadV2 *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_PadV2_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePadV2 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePadV2 *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CirclePow>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Pow)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePow *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePow *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Pow_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePow *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CirclePow *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleQuantize *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleQuantize *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleQuantize *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleQuantize>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Quantize)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleQuantize *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleQuantize *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Quantize_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleQuantize *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleQuantize *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleRange *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleRange *>(cn->find_clone(node));
+
+ luci::CircleNode *start = loco::must_cast<luci::CircleNode *>(node->start());
+ luci::CircleNode *limit = loco::must_cast<luci::CircleNode *>(node->limit());
+ luci::CircleNode *delta = loco::must_cast<luci::CircleNode *>(node->delta());
+
+ cloned->start(cn->find_clone(start));
+ cloned->limit(cn->find_clone(limit));
+ cloned->delta(cn->find_clone(delta));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleRange *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleRange>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->start(input(0));
+ node()->limit(input(1));
+ node()->delta(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Range)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRange *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRange *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_Range_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRange *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRange *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleRank *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleRank *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleRank *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleRank>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Rank)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRank *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRank *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Rank_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRank *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRank *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleReduceAny *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleReduceAny *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *reduction_indices =
+ loco::must_cast<luci::CircleNode *>(node->reduction_indices());
+
+ cloned->input(cn->find_clone(input));
+ cloned->reduction_indices(cn->find_clone(reduction_indices));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleReduceAny *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleReduceAny>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->reduction_indices(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ReduceAny)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceAny *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceAny *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ReduceAny_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceAny *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceAny *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleReduceMax *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleReduceMax *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *reduction_indices =
+ loco::must_cast<luci::CircleNode *>(node->reduction_indices());
+
+ cloned->input(cn->find_clone(input));
+ cloned->reduction_indices(cn->find_clone(reduction_indices));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleReduceMax *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleReduceMax>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->reduction_indices(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ReduceMax)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMax *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMax *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ReduceMax_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMax *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMax *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleReduceMin *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleReduceMin *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *reduction_indices =
+ loco::must_cast<luci::CircleNode *>(node->reduction_indices());
+
+ cloned->input(cn->find_clone(input));
+ cloned->reduction_indices(cn->find_clone(reduction_indices));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleReduceMin *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleReduceMin>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->reduction_indices(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ReduceMin)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMin *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMin *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ReduceMin_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMin *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMin *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleReduceProd *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleReduceProd *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *reduction_indices =
+ loco::must_cast<luci::CircleNode *>(node->reduction_indices());
+
+ cloned->input(cn->find_clone(input));
+ cloned->reduction_indices(cn->find_clone(reduction_indices));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleReduceProd *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleReduceProd>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->reduction_indices(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ReduceProd)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceProd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceProd *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ReduceProd_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceProd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceProd *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleRelu *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleRelu *>(cn->find_clone(node));
+
+ luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features());
+
+ cloned->features(cn->find_clone(features));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleRelu *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleRelu>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->features(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Relu)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Relu_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleRelu6 *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleRelu6 *>(cn->find_clone(node));
+
+ luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features());
+
+ cloned->features(cn->find_clone(features));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleRelu6 *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleRelu6>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->features(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Relu6)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu6 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu6 *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Relu6_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu6 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu6 *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleReluN1To1 *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleReluN1To1 *>(cn->find_clone(node));
+
+ luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features());
+
+ cloned->features(cn->find_clone(features));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleReluN1To1 *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleReluN1To1>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->features(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ReluN1To1)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReluN1To1 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReluN1To1 *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_ReluN1To1_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReluN1To1 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReluN1To1 *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleReshape *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleReshape *>(cn->find_clone(node));
+
+ luci::CircleNode *tensor = loco::must_cast<luci::CircleNode *>(node->tensor());
+ luci::CircleNode *shape = loco::must_cast<luci::CircleNode *>(node->shape());
+
+ cloned->tensor(cn->find_clone(tensor));
+ cloned->shape(cn->find_clone(shape));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleReshape *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleReshape>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->tensor(input(0));
+ node()->shape(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Reshape)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReshape *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReshape *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Reshape_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReshape *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReshape *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleResizeBilinear *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleResizeBilinear *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *size = loco::must_cast<luci::CircleNode *>(node->size());
+
+ cloned->input(cn->find_clone(input));
+ cloned->size(cn->find_clone(size));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleResizeBilinear *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleResizeBilinear>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->size(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ResizeBilinear)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeBilinear *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeBilinear *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ResizeBilinear_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeBilinear *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeBilinear *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleResizeNearestNeighbor *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleResizeNearestNeighbor *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *size = loco::must_cast<luci::CircleNode *>(node->size());
+
+ cloned->input(cn->find_clone(input));
+ cloned->size(cn->find_clone(size));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleResizeNearestNeighbor *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleResizeNearestNeighbor>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->size(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ResizeNearestNeighbor)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeNearestNeighbor *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeNearestNeighbor *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ResizeNearestNeighbor_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeNearestNeighbor *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeNearestNeighbor *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleReverseSequence *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleReverseSequence *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *seq_lengths = loco::must_cast<luci::CircleNode *>(node->seq_lengths());
+
+ cloned->input(cn->find_clone(input));
+ cloned->seq_lengths(cn->find_clone(seq_lengths));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleReverseSequence *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleReverseSequence>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->seq_lengths(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ReverseSequence)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseSequence *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseSequence *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ReverseSequence_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseSequence *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseSequence *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleReverseV2 *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleReverseV2 *>(cn->find_clone(node));
+
+ luci::CircleNode *tensor = loco::must_cast<luci::CircleNode *>(node->tensor());
+ luci::CircleNode *axis = loco::must_cast<luci::CircleNode *>(node->axis());
+
+ cloned->tensor(cn->find_clone(tensor));
+ cloned->axis(cn->find_clone(axis));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleReverseV2 *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleReverseV2>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->tensor(input(0));
+ node()->axis(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ReverseV2)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseV2 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseV2 *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_ReverseV2_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseV2 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseV2 *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleRound *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleRound *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleRound *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleRound>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Round)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRound *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRound *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Round_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRound *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRound *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleRsqrt>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Rsqrt)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRsqrt *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRsqrt *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Rsqrt_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRsqrt *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleRsqrt *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleScatterNd *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleScatterNd *>(cn->find_clone(node));
+
+ luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices());
+ luci::CircleNode *updates = loco::must_cast<luci::CircleNode *>(node->updates());
+ luci::CircleNode *shape = loco::must_cast<luci::CircleNode *>(node->shape());
+
+ cloned->indices(cn->find_clone(indices));
+ cloned->updates(cn->find_clone(updates));
+ cloned->shape(cn->find_clone(shape));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleScatterNd *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleScatterNd>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->indices(input(0));
+ node()->updates(input(1));
+ node()->shape(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ScatterNd)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleScatterNd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleScatterNd *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_ScatterNd_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleScatterNd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleScatterNd *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSegmentSum *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSegmentSum *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *segment_ids = loco::must_cast<luci::CircleNode *>(node->segment_ids());
+
+ cloned->input(cn->find_clone(input));
+ cloned->segment_ids(cn->find_clone(segment_ids));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSegmentSum *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSegmentSum>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->segment_ids(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SegmentSum)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSegmentSum *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSegmentSum *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_SegmentSum_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSegmentSum *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSegmentSum *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSelect *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSelect *>(cn->find_clone(node));
+
+ luci::CircleNode *condition = loco::must_cast<luci::CircleNode *>(node->condition());
+ luci::CircleNode *t = loco::must_cast<luci::CircleNode *>(node->t());
+ luci::CircleNode *e = loco::must_cast<luci::CircleNode *>(node->e());
+
+ cloned->condition(cn->find_clone(condition));
+ cloned->t(cn->find_clone(t));
+ cloned->e(cn->find_clone(e));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSelect *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSelect>
+{
+public:
+ NodeGraphlet() = default;
+};
+
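+// Test graph: three inputs wired to a CircleSelect (condition, t, e) and a single output.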
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->condition(input(0));
+ node()->t(input(1));
+ node()->e(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
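+// Positive case: after clone_connect(), the clone's arguments must be the prepared cloned inputs.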
+TEST(ConnectNodeTest, connect_Select)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSelect *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSelect *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
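+// Negative case: inputs are prepared with prepare_inputs_miss(), so reconnecting the clone is expected to throw.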
+TEST(ConnectNodeTest, connect_Select_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSelect *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSelect *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSelectV2 *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSelectV2 *>(cn->find_clone(node));
+
+ luci::CircleNode *condition = loco::must_cast<luci::CircleNode *>(node->condition());
+ luci::CircleNode *t = loco::must_cast<luci::CircleNode *>(node->t());
+ luci::CircleNode *e = loco::must_cast<luci::CircleNode *>(node->e());
+
+ cloned->condition(cn->find_clone(condition));
+ cloned->t(cn->find_clone(t));
+ cloned->e(cn->find_clone(e));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSelectV2 *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSelectV2>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->condition(input(0));
+ node()->t(input(1));
+ node()->e(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SelectV2)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSelectV2 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSelectV2 *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_SelectV2_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSelectV2 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSelectV2 *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleShape *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleShape *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleShape *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleShape>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Shape)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleShape *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleShape *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Shape_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleShape *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleShape *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSin *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSin *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSin *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSin>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Sin)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSin *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSin *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Sin_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSin *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSin *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSlice *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSlice *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *begin = loco::must_cast<luci::CircleNode *>(node->begin());
+ luci::CircleNode *size = loco::must_cast<luci::CircleNode *>(node->size());
+
+ cloned->input(cn->find_clone(input));
+ cloned->begin(cn->find_clone(begin));
+ cloned->size(cn->find_clone(size));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSlice *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSlice>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->begin(input(1));
+ node()->size(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Slice)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSlice *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSlice *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_Slice_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSlice *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSlice *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSoftmax *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSoftmax *>(cn->find_clone(node));
+
+ luci::CircleNode *logits = loco::must_cast<luci::CircleNode *>(node->logits());
+
+ cloned->logits(cn->find_clone(logits));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSoftmax *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSoftmax>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->logits(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Softmax)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSoftmax *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSoftmax *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Softmax_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSoftmax *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSoftmax *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSpaceToBatchND *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSpaceToBatchND *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *block_shape = loco::must_cast<luci::CircleNode *>(node->block_shape());
+ luci::CircleNode *paddings = loco::must_cast<luci::CircleNode *>(node->paddings());
+
+ cloned->input(cn->find_clone(input));
+ cloned->block_shape(cn->find_clone(block_shape));
+ cloned->paddings(cn->find_clone(paddings));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSpaceToBatchND *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSpaceToBatchND>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->block_shape(input(1));
+ node()->paddings(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SpaceToBatchND)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToBatchND *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToBatchND *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_SpaceToBatchND_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToBatchND *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToBatchND *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSpaceToDepth *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSpaceToDepth *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSpaceToDepth *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSpaceToDepth>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SpaceToDepth)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToDepth *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToDepth *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_SpaceToDepth_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToDepth *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToDepth *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSparseToDense *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSparseToDense *>(cn->find_clone(node));
+
+ luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices());
+ luci::CircleNode *output_shape = loco::must_cast<luci::CircleNode *>(node->output_shape());
+ luci::CircleNode *values = loco::must_cast<luci::CircleNode *>(node->values());
+ luci::CircleNode *default_value = loco::must_cast<luci::CircleNode *>(node->default_value());
+
+ cloned->indices(cn->find_clone(indices));
+ cloned->output_shape(cn->find_clone(output_shape));
+ cloned->values(cn->find_clone(values));
+ cloned->default_value(cn->find_clone(default_value));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSparseToDense *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSparseToDense>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<4>::init({shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->indices(input(0));
+ node()->output_shape(input(1));
+ node()->values(input(2));
+ node()->default_value(input(3));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SparseToDense)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSparseToDense *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSparseToDense *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(4, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+}
+
+TEST(ConnectNodeTest, connect_SparseToDense_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSparseToDense *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSparseToDense *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSplit *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSplit *>(cn->find_clone(node));
+
+ luci::CircleNode *split_dim = loco::must_cast<luci::CircleNode *>(node->split_dim());
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->split_dim(cn->find_clone(split_dim));
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSplit *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSplit>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->split_dim(input(0));
+ node()->input(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Split)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplit *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplit *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Split_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplit *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplit *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSplitOut *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSplitOut *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSplitOut *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSplitOut>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SplitOut)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitOut *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_SplitOut_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitOut *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSplitV *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSplitV *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *size_splits = loco::must_cast<luci::CircleNode *>(node->size_splits());
+ luci::CircleNode *split_dim = loco::must_cast<luci::CircleNode *>(node->split_dim());
+
+ cloned->input(cn->find_clone(input));
+ cloned->size_splits(cn->find_clone(size_splits));
+ cloned->split_dim(cn->find_clone(split_dim));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSplitV *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSplitV>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<3>::init({shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->size_splits(input(1));
+ node()->split_dim(input(2));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SplitV)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitV *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitV *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(3, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+}
+
+TEST(ConnectNodeTest, connect_SplitV_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitV *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitV *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSplitVOut *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSplitVOut *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSplitVOut *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSplitVOut>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SplitVOut)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitVOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitVOut *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_SplitVOut_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitVOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitVOut *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSqrt>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Sqrt)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSqrt *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSqrt *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Sqrt_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSqrt *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSqrt *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSquare *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSquare *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSquare *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSquare>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Square)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSquare *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSquare *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Square_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSquare *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSquare *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSquaredDifference>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_SquaredDifference)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSquaredDifference *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSquaredDifference *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_SquaredDifference_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSquaredDifference *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSquaredDifference *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSqueeze *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSqueeze *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSqueeze *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSqueeze>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Squeeze)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSqueeze *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSqueeze *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Squeeze_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSqueeze *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSqueeze *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleStridedSlice *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleStridedSlice *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *begin = loco::must_cast<luci::CircleNode *>(node->begin());
+ luci::CircleNode *end = loco::must_cast<luci::CircleNode *>(node->end());
+ luci::CircleNode *strides = loco::must_cast<luci::CircleNode *>(node->strides());
+
+ cloned->input(cn->find_clone(input));
+ cloned->begin(cn->find_clone(begin));
+ cloned->end(cn->find_clone(end));
+ cloned->strides(cn->find_clone(strides));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleStridedSlice *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleStridedSlice>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<4>::init({shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->begin(input(1));
+ node()->end(input(2));
+ node()->strides(input(3));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_StridedSlice)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleStridedSlice *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleStridedSlice *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(4, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+}
+
+TEST(ConnectNodeTest, connect_StridedSlice_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleStridedSlice *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleStridedSlice *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSum *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSum *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *reduction_indices =
+ loco::must_cast<luci::CircleNode *>(node->reduction_indices());
+
+ cloned->input(cn->find_clone(input));
+ cloned->reduction_indices(cn->find_clone(reduction_indices));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSum *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSum>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->reduction_indices(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Sum)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSum *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSum *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Sum_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSum *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSum *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleTanh *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleTanh *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleTanh *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleTanh>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Tanh)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTanh *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTanh *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Tanh_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTanh *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTanh *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleTile *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleTile *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *multiples = loco::must_cast<luci::CircleNode *>(node->multiples());
+
+ cloned->input(cn->find_clone(input));
+ cloned->multiples(cn->find_clone(multiples));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleTile *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleTile>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->multiples(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Tile)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTile *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTile *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Tile_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTile *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTile *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleTopKV2 *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleTopKV2 *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *k = loco::must_cast<luci::CircleNode *>(node->k());
+
+ cloned->input(cn->find_clone(input));
+ cloned->k(cn->find_clone(k));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleTopKV2 *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleTopKV2>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+ node()->k(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_TopKV2)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2 *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_TopKV2_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2 *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2 *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleTopKV2Out *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleTopKV2Out *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleTopKV2Out *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleTopKV2Out>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_TopKV2Out)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2Out *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2Out *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_TopKV2Out_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2Out *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2Out *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleTranspose *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleTranspose *>(cn->find_clone(node));
+
+ luci::CircleNode *a = loco::must_cast<luci::CircleNode *>(node->a());
+ luci::CircleNode *perm = loco::must_cast<luci::CircleNode *>(node->perm());
+
+ cloned->a(cn->find_clone(a));
+ cloned->perm(cn->find_clone(perm));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleTranspose *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleTranspose>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->a(input(0));
+ node()->perm(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Transpose)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTranspose *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTranspose *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Transpose_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTranspose *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTranspose *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleTransposeConv *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleTransposeConv *>(cn->find_clone(node));
+
+ luci::CircleNode *inputSizes = loco::must_cast<luci::CircleNode *>(node->inputSizes());
+ luci::CircleNode *filter = loco::must_cast<luci::CircleNode *>(node->filter());
+ luci::CircleNode *outBackprop = loco::must_cast<luci::CircleNode *>(node->outBackprop());
+ luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias());
+
+ cloned->inputSizes(cn->find_clone(inputSizes));
+ cloned->filter(cn->find_clone(filter));
+ cloned->outBackprop(cn->find_clone(outBackprop));
+ cloned->bias(cn->find_clone(bias));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleTransposeConv *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleTransposeConv>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleTransposeConv>::init(g);
+
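+ // give the node a concrete padding; cloning a TransposeConv with an
+ // UNDEFINED padding is expected to fail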
+ _node->padding(luci::Padding::VALID);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<4>::init({shape, shape, shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->inputSizes(input(0));
+ node()->filter(input(1));
+ node()->outBackprop(input(2));
+ node()->bias(input(3));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_TransposeConv)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTransposeConv *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTransposeConv *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(4, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+ ASSERT_EQ(cth.inputs(2), clone->arg(2));
+ ASSERT_EQ(cth.inputs(3), clone->arg(3));
+}
+
+TEST(ConnectNodeTest, connect_TransposeConv_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTransposeConv *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleTransposeConv *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
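+// UnidirectionalSequenceLSTM has 24 operands; look up the clone of each operand
+// of the original node and wire it to the corresponding input of the cloned node.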
+void connect(luci::ConnectNode *cn, const luci::CircleUnidirectionalSequenceLSTM *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ luci::CircleNode *input_to_input_weights =
+ loco::must_cast<luci::CircleNode *>(node->input_to_input_weights());
+ luci::CircleNode *input_to_forget_weights =
+ loco::must_cast<luci::CircleNode *>(node->input_to_forget_weights());
+ luci::CircleNode *input_to_cell_weights =
+ loco::must_cast<luci::CircleNode *>(node->input_to_cell_weights());
+ luci::CircleNode *input_to_output_weights =
+ loco::must_cast<luci::CircleNode *>(node->input_to_output_weights());
+
+ luci::CircleNode *recurrent_to_input_weights =
+ loco::must_cast<luci::CircleNode *>(node->recurrent_to_input_weights());
+ luci::CircleNode *recurrent_to_forget_weights =
+ loco::must_cast<luci::CircleNode *>(node->recurrent_to_forget_weights());
+ luci::CircleNode *recurrent_to_cell_weights =
+ loco::must_cast<luci::CircleNode *>(node->recurrent_to_cell_weights());
+ luci::CircleNode *recurrent_to_output_weights =
+ loco::must_cast<luci::CircleNode *>(node->recurrent_to_output_weights());
+
+ luci::CircleNode *cell_to_input_weights =
+ loco::must_cast<luci::CircleNode *>(node->cell_to_input_weights());
+ luci::CircleNode *cell_to_forget_weights =
+ loco::must_cast<luci::CircleNode *>(node->cell_to_forget_weights());
+ luci::CircleNode *cell_to_output_weights =
+ loco::must_cast<luci::CircleNode *>(node->cell_to_output_weights());
+
+ luci::CircleNode *input_gate_bias = loco::must_cast<luci::CircleNode *>(node->input_gate_bias());
+ luci::CircleNode *forget_gate_bias =
+ loco::must_cast<luci::CircleNode *>(node->forget_gate_bias());
+ luci::CircleNode *cell_gate_bias = loco::must_cast<luci::CircleNode *>(node->cell_gate_bias());
+ luci::CircleNode *output_gate_bias =
+ loco::must_cast<luci::CircleNode *>(node->output_gate_bias());
+
+ luci::CircleNode *projection_weights =
+ loco::must_cast<luci::CircleNode *>(node->projection_weights());
+ luci::CircleNode *projection_bias = loco::must_cast<luci::CircleNode *>(node->projection_bias());
+
+ luci::CircleNode *activation_state =
+ loco::must_cast<luci::CircleNode *>(node->activation_state());
+ luci::CircleNode *cell_state = loco::must_cast<luci::CircleNode *>(node->cell_state());
+
+ luci::CircleNode *input_layer_norm_coefficients =
+ loco::must_cast<luci::CircleNode *>(node->input_layer_norm_coefficients());
+ luci::CircleNode *forget_layer_norm_coefficients =
+ loco::must_cast<luci::CircleNode *>(node->forget_layer_norm_coefficients());
+ luci::CircleNode *cell_layer_norm_coefficients =
+ loco::must_cast<luci::CircleNode *>(node->cell_layer_norm_coefficients());
+ luci::CircleNode *output_layer_norm_coefficients =
+ loco::must_cast<luci::CircleNode *>(node->output_layer_norm_coefficients());
+
+ cloned->input(cn->find_clone(input));
+
+ cloned->input_to_input_weights(cn->find_clone(input_to_input_weights));
+ cloned->input_to_forget_weights(cn->find_clone(input_to_forget_weights));
+ cloned->input_to_cell_weights(cn->find_clone(input_to_cell_weights));
+ cloned->input_to_output_weights(cn->find_clone(input_to_output_weights));
+
+ cloned->recurrent_to_input_weights(cn->find_clone(recurrent_to_input_weights));
+ cloned->recurrent_to_forget_weights(cn->find_clone(recurrent_to_forget_weights));
+ cloned->recurrent_to_cell_weights(cn->find_clone(recurrent_to_cell_weights));
+ cloned->recurrent_to_output_weights(cn->find_clone(recurrent_to_output_weights));
+
+ cloned->cell_to_input_weights(cn->find_clone(cell_to_input_weights));
+ cloned->cell_to_forget_weights(cn->find_clone(cell_to_forget_weights));
+ cloned->cell_to_output_weights(cn->find_clone(cell_to_output_weights));
+
+ cloned->input_gate_bias(cn->find_clone(input_gate_bias));
+ cloned->forget_gate_bias(cn->find_clone(forget_gate_bias));
+ cloned->cell_gate_bias(cn->find_clone(cell_gate_bias));
+ cloned->output_gate_bias(cn->find_clone(output_gate_bias));
+
+ cloned->projection_weights(cn->find_clone(projection_weights));
+ cloned->projection_bias(cn->find_clone(projection_bias));
+
+ cloned->activation_state(cn->find_clone(activation_state));
+ cloned->cell_state(cn->find_clone(cell_state));
+
+ cloned->input_layer_norm_coefficients(cn->find_clone(input_layer_norm_coefficients));
+ cloned->forget_layer_norm_coefficients(cn->find_clone(forget_layer_norm_coefficients));
+ cloned->cell_layer_norm_coefficients(cn->find_clone(cell_layer_norm_coefficients));
+ cloned->output_layer_norm_coefficients(cn->find_clone(output_layer_norm_coefficients));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleUnidirectionalSequenceLSTM *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleUnidirectionalSequenceLSTM>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleUnidirectionalSequenceLSTM>::init(g);
+
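+ // give the node a concrete fused activation; cloning with an UNDEFINED
+ // activation is expected to fail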
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<24>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<24>::init({shape, shape, shape, shape, shape, shape, shape, shape,
+ shape, shape, shape, shape, shape, shape, shape, shape,
+ shape, shape, shape, shape, shape, shape, shape, shape},
+ shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input(0));
+
+ node()->input_to_input_weights(input(1));
+ node()->input_to_forget_weights(input(2));
+ node()->input_to_cell_weights(input(3));
+ node()->input_to_output_weights(input(4));
+
+ node()->recurrent_to_input_weights(input(5));
+ node()->recurrent_to_forget_weights(input(6));
+ node()->recurrent_to_cell_weights(input(7));
+ node()->recurrent_to_output_weights(input(8));
+
+ node()->cell_to_input_weights(input(9));
+ node()->cell_to_forget_weights(input(10));
+ node()->cell_to_output_weights(input(11));
+
+ node()->input_gate_bias(input(12));
+ node()->forget_gate_bias(input(13));
+ node()->cell_gate_bias(input(14));
+ node()->output_gate_bias(input(15));
+
+ node()->projection_weights(input(16));
+ node()->projection_bias(input(17));
+
+ node()->activation_state(input(18));
+ node()->cell_state(input(19));
+
+ node()->input_layer_norm_coefficients(input(20));
+ node()->forget_layer_norm_coefficients(input(21));
+ node()->cell_layer_norm_coefficients(input(22));
+ node()->output_layer_norm_coefficients(input(23));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_UnidirectionalSequenceLSTM)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(24, clone->arity());
+ // 24 individual assertions would be excessive, so verify all inputs in a loop
+ for (uint32_t i = 0; i < 24; ++i)
+ ASSERT_EQ(cth.inputs(i), clone->arg(i));
+}
+
+TEST(ConnectNodeTest, connect_UnidirectionalSequenceLSTM_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleUnique *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleUnique *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleUnique *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleUnique>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Unique)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnique *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnique *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Unique_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnique *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnique *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleUniqueOut *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleUniqueOut *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleUniqueOut *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleUniqueOut>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_UniqueOut)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUniqueOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUniqueOut *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_UniqueOut_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUniqueOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUniqueOut *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleUnpack *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleUnpack *>(cn->find_clone(node));
+
+ luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value());
+
+ cloned->value(cn->find_clone(value));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleUnpack *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleUnpack>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->value(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Unpack)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpack *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpack *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Unpack_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpack *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpack *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleUnpackOut *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleUnpackOut *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleUnpackOut *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleUnpackOut>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_UnpackOut)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpackOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpackOut *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_UnpackOut_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpackOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpackOut *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleWhere *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleWhere *>(cn->find_clone(node));
+
+ luci::CircleNode *condition = loco::must_cast<luci::CircleNode *>(node->condition());
+
+ cloned->condition(cn->find_clone(condition));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleWhere *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleWhere>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->condition(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Where)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhere *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhere *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_Where_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhere *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhere *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
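+// CircleWhile has a variable number of inputs, so reconnect them in a loop
+// rather than one by one.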
+void connect(luci::ConnectNode *cn, const luci::CircleWhile *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleWhile *>(cn->find_clone(node));
+
+ auto input_count = node->input_count();
+ for (uint32_t in = 0; in < input_count; ++in)
+ {
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input(in));
+
+ cloned->input(in, cn->find_clone(input));
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleWhile *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
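+// CircleWhile is a multi-input/multi-output op, so this test uses the
+// NodeIsOsGraphletT and TestIsOsGraph helpers with one input and one output.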
+class NodeGraphlet : public NodeIsOsGraphletT<luci::CircleWhile>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, uint32_t n, uint32_t m) override { NodeIsOsGraphletT::init(g, n, m); }
+};
+
+class TestNodeGraph : public TestIsOsGraph<1, 1>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOsGraph<1, 1>::init({shape}, {shape});
+ NodeGraphlet::init(g(), 1, 1);
+
+ node()->input(0, input(0));
+
+ output(0)->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_While)
+{
+ TestNodeGraph tng;
+ tng.init({1});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs<1, 1>(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhile *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhile *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_While_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({1});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss<1, 1>(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhile *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhile *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleWhileOut *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleWhileOut *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleWhileOut *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleWhileOut>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_WhileOut)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhileOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhileOut *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_WhileOut_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhileOut *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleWhileOut *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleZerosLike *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleZerosLike *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+
+ cloned->input(cn->find_clone(input));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleZerosLike *node) { connect(this, node); }
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleZerosLike>
+{
+public:
+ NodeGraphlet() = default;
+};
+
+class TestNodeGraph : public TestIOGraph, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ NodeGraphlet::init(g());
+
+ node()->input(input());
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_ZerosLike)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleZerosLike *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleZerosLike *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(1, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+}
+
+TEST(ConnectNodeTest, connect_ZerosLike_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleZerosLike *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleZerosLike *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
luci::PartitionTable pt;
pt.default_group = "A";
+ pt.comply = luci::PartitionTable::COMPLY::OPCODE;
auto pms = apply(&module, pt);
LOGGER(l);
- // TODO support multiple subgraph
- assert(source->size() == 1);
-
INFO(l) << "--- Cleanup unused inputs/outputs";
// remove input within same pgroup
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/PartitionDump.h"
+
+namespace
+{
+
+void dump(std::ostream &os, const luci::PartitionTable &table)
+{
+ os << "Backends:";
+ for (auto &group : table.groups)
+ {
+ os << " " << group;
+ if (table.default_group == group)
+ os << "(default)";
+ }
+ os << std::endl;
+
+ os << "Assign by OPCODE: " << std::endl;
+ for (auto &item : table.byopcodes)
+ os << " " << item.first << "=" << item.second << std::endl;
+
+ os << "Assign by OPNAME: " << std::endl;
+ for (auto &item : table.byopnames)
+ os << " " << item.first << "=" << item.second << std::endl;
+}
+
+} // namespace
+
+std::ostream &operator<<(std::ostream &os, const luci::PartitionTable &table)
+{
+ dump(os, table);
+ return os;
+}
return std::move(d_pgroups);
}
-std::string PGroups::group_of(luci::CircleNode *node) const
+GroupKey PGroups::group_of(luci::CircleNode *node) const
{
assert(node != nullptr);
struct PGroup;
+using GroupKey = std::string;
+
/**
* @brief Partition Node with CircleNode with group name
* @note node just points to source luci::CircleNode, NOT the cloned node
struct PNode
{
const luci::CircleNode *node = nullptr;
- std::string group;
+ GroupKey group;
const PGroup *pgroup = nullptr;
};
struct PGroup
{
std::vector<std::unique_ptr<PNode>> pnodes;
- std::string group;
+ GroupKey group;
uint32_t id = 0;
// I/O while partitioning
std::vector<std::unique_ptr<PGroup>> pgroups;
// node2group is to find group key from source node
- std::map<const luci::CircleNode *, std::string> node2group;
+ std::map<const luci::CircleNode *, GroupKey> node2group;
 // id2pgroup is to find *pgroup from pgroup id
std::map<uint32_t, PGroup *> id2pgroup;
// default group key for reference
- std::string default_group;
+ GroupKey default_group;
public:
/**
/**
* @brief return group key of node, empty string if not found
*/
- std::string group_of(luci::CircleNode *node) const;
+ GroupKey group_of(luci::CircleNode *node) const;
/**
* @brief return holding pgroup of node, nullptr if not found
std::string group;
for (auto &input : pgroup->inputs)
{
+ // Skip the logic below for CircleConst.
+ // A CircleConst will be cloned if it is not found in the pgroup as an input.
+ // Refer to build_graph(), "add CircleConst for inputs".
+ // Reason: a CircleConst can be shared as an input to multiple nodes,
+ //         and those nodes can be placed in different groups. In that case
+ //         the CircleConst must be cloned for each group's graph.
+ if (dynamic_cast<const luci::CircleConst *>(input) != nullptr)
+ continue;
+
auto input_group = pgroups->group_of(input);
// NOTE: all the nodes should be registered and return should be valid group.
- // convert_to_proups() should ensure this.
+ // produce_pgroups() should ensure this, except for CircleConst, Input, and Output nodes.
// assert here to find if there is any problem with this.
assert(not input_group.empty());
if (input_group.empty())
const luci::PartitionTable &partition)
{
assert(source != nullptr);
- // TODO support multiple subgraphs
- assert(source->size() == 1);
+ // NOTE Only the main graph (subgraph index 0) will be partitioned.
+ //      Other subgraphs follow the group of their owner node (IF/WHILE/...).
LOGGER(l);
// check if node is normal node that we are interested
if (check_allocate_partition(node))
{
- auto opcodename = luci::opcode_name(node);
- assert(!opcodename.empty());
-
auto group = partition.default_group;
- auto it = partition.byopcodes.find(opcodename);
- if (it != partition.byopcodes.end())
- group = it->second;
+
+ std::string opcodename; // opcodename or opname
+
+ switch (partition.comply)
+ {
+ case luci::PartitionTable::COMPLY::OPCODE:
+ {
+ opcodename = luci::opcode_name(node);
+ assert(!opcodename.empty());
+
+ auto it = partition.byopcodes.find(opcodename);
+ if (it != partition.byopcodes.end())
+ group = it->second;
+ break;
+ }
+ case luci::PartitionTable::COMPLY::OPNAME:
+ {
+ opcodename = node->name();
+ assert(!opcodename.empty());
+
+ auto it = partition.byopnames.find(opcodename);
+ if (it != partition.byopnames.end())
+ group = it->second;
+ break;
+ }
+
+ default:
+ throw std::runtime_error("Unsupported partition.comply");
+ }
INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
<< std::endl;
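
In short, the switch above selects the lookup table according to partition.comply: COMPLY::OPCODE matches luci::opcode_name(node) against byopcodes, COMPLY::OPNAME matches node->name() against byopnames, and any unmatched node stays in default_group. A standalone sketch of that decision follows; group_for is a hypothetical helper written for illustration, not part of the patch, and it only uses the PartitionTable fields visible in this diff.

#include "luci/Partition.h" // assumed header defining luci::PartitionTable

#include <string>

std::string group_for(const luci::PartitionTable &partition, const std::string &opcodename,
                      const std::string &opname)
{
  std::string group = partition.default_group;
  if (partition.comply == luci::PartitionTable::COMPLY::OPCODE)
  {
    auto it = partition.byopcodes.find(opcodename);
    if (it != partition.byopcodes.end())
      group = it->second;
  }
  else // luci::PartitionTable::COMPLY::OPNAME
  {
    auto it = partition.byopnames.find(opname);
    if (it != partition.byopnames.end())
      group = it->second;
  }
  return group;
}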
luci::PartitionTable pt;
pt.default_group = "A";
+ pt.comply = luci::PartitionTable::COMPLY::OPCODE;
auto pgs = produce_pgroups(&module, pt);
namespace
{
+// forward declare
+void clone_ifnode_subgraphs(luci::PartedModule &pm, const luci::CircleIf *if_node,
+ const luci::CloneContext &clonectx);
+void clone_whilenode_subgraphs(luci::PartedModule &pm, const luci::CircleWhile *while_node,
+ const luci::CloneContext &clonectx);
+
void add_graph_input(loco::Graph *graph, luci::CircleInput *input_node)
{
assert(graph != nullptr);
}
/**
+ * @brief Make a clone of the graph
+ */
+std::unique_ptr<loco::Graph> clone_graph(loco::Graph *graph_org, luci::CloneContext &clonectx)
+{
+ auto graph = loco::make_graph();
+ auto graph_clone = graph.get();
+
+ graph_clone->name(graph_org->name());
+
+ // clone inputs
+ for (uint32_t n = 0; n < graph_org->inputs()->size(); ++n)
+ {
+ auto input_org = luci::input_node(graph_org, n);
+ assert(input_org != nullptr);
+
+ auto *input_clone = graph_clone->nodes()->create<luci::CircleInput>();
+ luci::copy_common_attributes(input_org, input_clone);
+
+ add_graph_input(graph_clone, input_clone);
+ clonectx.emplace(input_org, input_clone);
+ }
+
+ // clone nodes
+ auto nodes = graph_org->nodes();
+ for (uint32_t n = 0; n < nodes->size(); ++n)
+ {
+ auto node = nodes->at(n);
+
+ // skip for CircleInput, CircleOutput
+ if (dynamic_cast<luci::CircleInput *>(node) != nullptr)
+ continue;
+ if (dynamic_cast<luci::CircleOutput *>(node) != nullptr)
+ continue;
+
+ auto node_org = loco::must_cast<luci::CircleNode *>(node);
+ assert(clonectx.find(node_org) == clonectx.end());
+
+ auto *node_clone = clone_node(node_org, graph_clone);
+ clonectx.emplace(node_org, node_clone);
+ }
+
+ // connect nodes
+ for (uint32_t n = 0; n < nodes->size(); ++n)
+ {
+ auto node = nodes->at(n);
+
+ // skip for CircleInput, CircleOutput
+ if (dynamic_cast<luci::CircleInput *>(node) != nullptr)
+ continue;
+ if (dynamic_cast<luci::CircleOutput *>(node) != nullptr)
+ continue;
+
+ auto node_org = loco::must_cast<luci::CircleNode *>(node);
+ clone_connect(node_org, clonectx);
+ }
+
+ // clone outputs
+ for (uint32_t n = 0; n < graph_org->outputs()->size(); ++n)
+ {
+ auto output_org = luci::output_node(graph_org, n);
+ assert(output_org != nullptr);
+
+ auto *output_clone = graph_clone->nodes()->create<luci::CircleOutput>();
+ luci::copy_common_attributes(output_org, output_clone);
+ // note: we don't add output_clone to clonectx.
+ // logically, an output is not used as an input to any other node.
+ auto output_from = loco::must_cast<luci::CircleNode *>(output_org->from());
+ auto it = clonectx.find(output_from);
+ assert(it != clonectx.end());
+ output_clone->from(it->second);
+
+ add_graph_output(graph_clone, output_clone);
+ }
+
+ return std::move(graph);
+}
+
+void clone_recursive_subgraphs(luci::PartedModule &pm, loco::Graph *graph,
+ const luci::CloneContext &clonectx)
+{
+ auto nodes = graph->nodes();
+ for (uint32_t n = 0; n < nodes->size(); ++n)
+ {
+ {
+ auto if_node = dynamic_cast<luci::CircleIf *>(nodes->at(n));
+ if (if_node != nullptr)
+ {
+ clone_ifnode_subgraphs(pm, if_node, clonectx);
+ }
+ }
+ {
+ auto while_node = dynamic_cast<luci::CircleWhile *>(nodes->at(n));
+ if (while_node != nullptr)
+ {
+ clone_whilenode_subgraphs(pm, while_node, clonectx);
+ }
+ }
+ // TODO handle others
+ }
+}
+
+void clone_ifnode_subgraphs(luci::PartedModule &pm, const luci::CircleIf *if_node,
+ const luci::CloneContext &clonectx)
+{
+ assert(if_node != nullptr);
+
+ auto it = clonectx.find(if_node);
+ assert(it != clonectx.end());
+ auto if_clone = loco::must_cast<luci::CircleIf *>(it->second);
+
+ luci::CloneContext then_clonectx;
+ luci::CloneContext else_clonectx;
+
+ auto then_graph = if_node->then_graph();
+ auto else_graph = if_node->else_graph();
+
+ auto then_clone = clone_graph(then_graph, then_clonectx);
+ auto else_clone = clone_graph(else_graph, else_clonectx);
+ if_clone->then_graph(then_clone.get());
+ if_clone->else_graph(else_clone.get());
+
+ pm.module->add(std::move(then_clone));
+ int32_t then_index = pm.module->size() - 1;
+ pm.module->add(std::move(else_clone));
+ int32_t else_index = pm.module->size() - 1;
+ if_clone->then_branch(then_index);
+ if_clone->else_branch(else_index);
+
+ // Recursively clone subgraphs of any nested IF/WHILE nodes
+ // inside then_graph or else_graph.
+ clone_recursive_subgraphs(pm, then_graph, then_clonectx);
+ clone_recursive_subgraphs(pm, else_graph, else_clonectx);
+}
+
+void clone_whilenode_subgraphs(luci::PartedModule &pm, const luci::CircleWhile *while_node,
+ const luci::CloneContext &clonectx)
+{
+ assert(while_node != nullptr);
+
+ auto it = clonectx.find(while_node);
+ assert(it != clonectx.end());
+ auto while_clone = loco::must_cast<luci::CircleWhile *>(it->second);
+
+ luci::CloneContext cond_clonectx;
+ luci::CloneContext body_clonectx;
+
+ auto cond_graph = while_node->cond_graph();
+ auto body_graph = while_node->body_graph();
+
+ auto cond_clone = clone_graph(cond_graph, cond_clonectx);
+ auto body_clone = clone_graph(body_graph, body_clonectx);
+ while_clone->cond_graph(cond_clone.get());
+ while_clone->body_graph(body_clone.get());
+
+ pm.module->add(std::move(cond_clone));
+ int32_t cond_index = pm.module->size() - 1;
+ pm.module->add(std::move(body_clone));
+ int32_t body_index = pm.module->size() - 1;
+ while_clone->cond_branch(cond_index);
+ while_clone->body_branch(body_index);
+
+ // Recursively clone subgraphs of any nested IF/WHILE nodes
+ // inside cond_graph or body_graph.
+ clone_recursive_subgraphs(pm, cond_graph, cond_clonectx);
+ clone_recursive_subgraphs(pm, body_graph, body_clonectx);
+}
+
+/**
* @brief Build loco::graph from pgroup into graph
*/
-void build_graph(loco::Graph *graph, const luci::PGroup *pgroup)
+void build_graph(luci::PartedModule &pm, loco::Graph *graph, const luci::PGroup *pgroup)
{
LOGGER(l);
<< "output(" << output << ") -> " << output_clone << "(" << output_clone->name() << ")"
<< ": from " << it->second << "(" << it->second->name() << ")";
}
+
+ // TODO relocate this if needed
+ // subgraphs for IF/WHILE/... nodes
+ for (auto &pnode : pgroup->pnodes)
+ {
+ {
+ auto if_node = dynamic_cast<const luci::CircleIf *>(pnode->node);
+ if (if_node != nullptr)
+ {
+ clone_ifnode_subgraphs(pm, if_node, clonectx);
+ }
+ }
+ {
+ auto while_node = dynamic_cast<const luci::CircleWhile *>(pnode->node);
+ if (while_node != nullptr)
+ {
+ clone_whilenode_subgraphs(pm, while_node, clonectx);
+ }
+ }
+ // TODO handle others
+ }
}
std::string make_name(const luci::PGroup *pgroup)
pm.module = std::make_unique<luci::Module>();
pm.group = pgroup->group;
+ // the main graph for this module
auto graph = loco::make_graph();
+ auto graph_ptr = graph.get();
auto graph_name = make_name(pgroup.get());
graph->name(graph_name);
+ // Add main graph so that other subgraphs can be added inside build_graph
+ pm.module->add(std::move(graph));
+
INFO(l) << "--- Partition Graph build----------------------";
INFO(l) << "--- name: " << graph_name;
- build_graph(graph.get(), pgroup.get());
+ build_graph(pm, graph_ptr, pgroup.get());
- pm.module->add(std::move(graph));
pms.pmodules.emplace_back(std::move(pm));
}
luci::PartitionTable pt;
pt.default_group = "A";
+ pt.comply = luci::PartitionTable::COMPLY::OPCODE;
auto pgs = produce_pgroups(&module, pt);
auto pms = produce_pmodules(pgs.get());
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/PartitionValidate.h"
+
+#include <luci/Service/Validate.h>
+
+#include <pepper/csv2vec.h>
+
+#include <iostream>
+
+namespace luci
+{
+
+bool validate(luci::PartitionTable &partition)
+{
+ if (partition.groups.size() == 0)
+ {
+ std::cerr << "There is no 'backends' information";
+ return false;
+ }
+ if (partition.default_group.empty())
+ {
+ std::cerr << "There is no 'default' backend information";
+ return false;
+ }
+ if (!pepper::is_one_of<std::string>(partition.default_group, partition.groups))
+ {
+ std::cerr << "'default' backend is not one of 'backends' item";
+ return false;
+ }
+ for (auto &byopcode : partition.byopcodes)
+ {
+ if (!pepper::is_one_of<std::string>(byopcode.second, partition.groups))
+ {
+ std::cerr << "OPCODE " << byopcode.first << " is not assigned to one of 'backends' items";
+ return false;
+ }
+ }
+ for (auto &byopname : partition.byopnames)
+ {
+ if (!pepper::is_one_of<std::string>(byopname.second, partition.groups))
+ {
+ std::cerr << "OPNAME " << byopname.first << " is not assigned to one of 'backends' items";
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace luci
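
A small caller sketch for the validate() above, illustrative only: the header name mirrors the include in this file, and the field types are assumed to match how validate() reads them.

#include "luci/PartitionValidate.h" // assumed to declare luci::validate(luci::PartitionTable &)

#include <iostream>

int main()
{
  luci::PartitionTable table;
  table.groups = {"cpu", "npu"};
  table.default_group = "gpu"; // not listed in groups, so validation must fail

  if (!luci::validate(table))
  {
    std::cerr << std::endl << "partition table rejected" << std::endl;
    return 1;
  }
  return 0;
}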
target_link_libraries(luci_pass PRIVATE luci_logex)
target_link_libraries(luci_pass PRIVATE luci_profile)
target_link_libraries(luci_pass PRIVATE nncc_common)
+target_link_libraries(luci_pass PRIVATE pepper_csv2vec)
target_link_libraries(luci_pass PRIVATE oops)
install(TARGETS luci_pass DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
if(NOT ENABLE_TEST)
return()
FuseBatchNormWithTConv,
FuseBCQ,
FuseInstanceNorm,
+ FuseMeanWithMean,
+ FuseTransposeWithMean,
ResolveCustomOpAdd,
ResolveCustomOpBatchMatMul,
ResolveCustomOpMatMul,
+ ResolveCustomOpMaxPoolWithArgmax,
QuantizeDequantizeWeights,
QuantizeWithMinMax,
Requantize,
ShuffleWeightTo16x1Float32,
RemoveRedundantTranspose,
ReplaceMulAddWithDepthwiseConv,
+ ReplaceSubWithAdd,
SubstitutePackToReshape,
+ SubstitutePadV2ToPad,
SubstituteSqueezeToReshape,
ConvertNCHWToNHWC,
RemoveUnnecessarySlice,
RemoveUnnecessarySplit,
RemoveUnnecessaryReshape,
TransformMinMaxToRelu6Pass,
+ TransformMinReluToRelu6Pass,
+ SubstituteStridedSliceToReshape,
SubstituteTransposeToReshape,
RemoveRedundantReshape,
+ RemoveFakeQuant,
+ RemoveQuantDequantSeq,
};
enum AlgorithmParameters
Sparsify_block_map,
// convert NCHW to NHWC
- NCHW_to_NHWC_preserve_input_shape,
- NCHW_to_NHWC_preserve_output_shape,
+ NCHW_to_NHWC_input_shape,
+ NCHW_to_NHWC_output_shape,
};
virtual ~Options() = default;
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FUSE_MEAN_WITH_MEAN_PASS_H__
+#define __LUCI_FUSE_MEAN_WITH_MEAN_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fuse two consecutive Mean operations into one Mean
+ *        with merged reduction indices
+ */
+struct FuseMeanWithMeanPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FuseMeanWithMeanPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FUSE_MEAN_WITH_MEAN_PASS_H__
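
FuseMeanWithMeanPass relies on the identity that taking a mean along one set of axes and then a mean along another equals a single mean over the merged axes, because every inner group has the same size. A tiny standalone check of that identity (plain C++, not tied to luci):

#include <cassert>

int main()
{
  // 2x2 "tensor" {{1, 3}, {5, 7}}
  const double mean_row0 = (1.0 + 3.0) / 2.0; // first Mean over axis 1, row 0
  const double mean_row1 = (5.0 + 7.0) / 2.0; // first Mean over axis 1, row 1

  const double mean_of_means = (mean_row0 + mean_row1) / 2.0; // second Mean over axis 0
  const double mean_merged = (1.0 + 3.0 + 5.0 + 7.0) / 4.0;   // single Mean over merged axes {0, 1}

  assert(mean_of_means == mean_merged);
  return 0;
}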
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FUSE_TRANSPOSE_WITH_MEAN_PASS_H__
+#define __LUCI_FUSE_TRANSPOSE_WITH_MEAN_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fuse Mean operation with a preceding Transpose
+ */
+struct FuseTransposeWithMeanPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FuseTransposeWithMeanPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FUSE_TRANSPOSE_WITH_MEAN_PASS_H__
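
FuseTransposeWithMeanPass folds a Transpose into the following Mean by remapping the reduction indices through the permutation. A small standalone check that reducing a transposed tensor over an axis equals reducing the original tensor over the permuted axis (plain C++, illustrative only):

#include <cassert>

int main()
{
  const double x[2][3] = {{1, 2, 3}, {4, 5, 6}};

  // y = Transpose(x): y[j][i] = x[i][j]
  double y[3][2];
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      y[j][i] = x[i][j];

  for (int i = 0; i < 2; ++i)
  {
    const double mean_after_transpose = (y[0][i] + y[1][i] + y[2][i]) / 3.0; // Mean(Transpose(x), axis 0)
    const double mean_direct = (x[i][0] + x[i][1] + x[i][2]) / 3.0;          // Mean(x, axis 1)
    assert(mean_after_transpose == mean_direct);
  }
  return 0;
}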
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_FAKEQUANT_PASS_H__
+#define __LUCI_REMOVE_FAKEQUANT_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Remove FakeQuant node.
+ */
+struct RemoveFakeQuantPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveFakeQuantPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_FAKEQUANT_PASS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_QUANTDEQUANTSEQ_PASS_H__
+#define __LUCI_REMOVE_QUANTDEQUANTSEQ_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Remove Quantize-Dequantize sequence.
+ */
+struct RemoveQuantDequantSeqPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveQuantDequantSeqPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_QUANTDEQUANTSEQ_PASS_H__
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* limitations under the License.
*/
-#ifndef __CIRCLE2CIRCLE_MODEL_H__
-#define __CIRCLE2CIRCLE_MODEL_H__
+#ifndef __LUCI_REPLACE_SUB_WITH_ADD_PASS_H__
+#define __LUCI_REPLACE_SUB_WITH_ADD_PASS_H__
-#include <mio/circle/schema_generated.h>
-
-#include <memory>
+#include <logo/Pass.h>
namespace luci
{
-struct Model
-{
- virtual ~Model() = default;
-
- virtual const ::circle::Model *model(void) = 0;
-};
-
/**
- * @brief Load Circle model (as a raw Model) from a given path
+ * @brief Class to Replace Sub With Add
*
- * @note May return a nullptr
*/
-std::unique_ptr<Model> load_model(const std::string &path);
+struct ReplaceSubWithAddPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ReplaceSubWithAddPass"; }
+
+ bool run(loco::Graph *g) final;
+};
} // namespace luci
-#endif // __CIRCLE2CIRCLE_MODEL_H__
+#endif // __LUCI_REPLACE_SUB_WITH_ADD_PASS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_RESOLVE_CUSTOM_OP_MAXPOOL_WITH_ARGMAX_PASS_H__
+#define __LUCI_RESOLVE_CUSTOM_OP_MAXPOOL_WITH_ARGMAX_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to resolve the custom op MaxPoolWithArgmax into a subgraph of Circle's MaxPool and ArgMax.
+ */
+struct ResolveCustomOpMaxPoolWithArgmaxPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ResolveCustomOpMaxPoolWithArgmaxPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_RESOLVE_CUSTOM_OP_MAXPOOL_WITH_ARGMAX_PASS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SUBSTITUTE_PADV2_TO_PAD_PASS_H__
+#define __LUCI_SUBSTITUTE_PADV2_TO_PAD_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to substitute PadV2 with Pad under certain conditions.
+ */
+struct SubstitutePadV2ToPadPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::SubstitutePadV2ToPadPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_SUBSTITUTE_PADV2_TO_PAD_PASS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SUBSTITUTE_STRIDED_SLICE_TO_RESHAPE_PASS_H__
+#define __LUCI_SUBSTITUTE_STRIDED_SLICE_TO_RESHAPE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to substitute StridedSlice, under certain conditions, with a single Reshape node.
+ */
+struct SubstituteStridedSliceToReshapePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::SubstituteStridedSliceToReshapePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_SUBSTITUTE_STRIDED_SLICE_TO_RESHAPE_PASS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_TRANSFORM_MIN_RELU_TO_RELU6_PASS_H__
+#define __LUCI_TRANSFORM_MIN_RELU_TO_RELU6_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to transform Relu(Minimum(input, 6)) to Relu6
+ */
+struct TransformMinReluToRelu6Pass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::TransformMinReluToRelu6Pass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_TRANSFORM_MIN_RELU_TO_RELU6_PASS_H__
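
The rewrite is valid because Relu(Minimum(x, 6)) and Relu6(x) compute the same value for every x: both clamp to [0, 6]. A quick numeric check (plain C++, illustrative only):

#include <algorithm>
#include <cassert>

int main()
{
  for (const double x : {-3.0, 0.0, 2.5, 6.0, 9.0})
  {
    const double relu_of_min = std::max(std::min(x, 6.0), 0.0); // Relu(Minimum(x, 6))
    const double relu6 = std::min(std::max(x, 0.0), 6.0);       // Relu6(x)
    assert(relu_of_min == relu6);
  }
  return 0;
}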
#include "luci/Pass/FuseBatchNormWithTConvPass.h"
#include "luci/Pass/FuseBCQPass.h"
#include "luci/Pass/FuseInstanceNormPass.h"
+#include "luci/Pass/FuseMeanWithMeanPass.h"
#include "luci/Pass/FusePreActivationBatchNormPass.h"
+#include "luci/Pass/FuseTransposeWithMeanPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
#include "luci/Pass/PropagateQuantParamPass.h"
+#include "luci/Pass/RemoveFakeQuantPass.h"
+#include "luci/Pass/RemoveQuantDequantSeqPass.h"
#include "luci/Pass/RemoveRedundantReshapePass.h"
#include "luci/Pass/RemoveRedundantTransposePass.h"
#include "luci/Pass/RemoveUnnecessaryReshapePass.h"
#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
#include "luci/Pass/RemoveUnnecessarySplitPass.h"
#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
+#include "luci/Pass/ReplaceSubWithAddPass.h"
#include "luci/Pass/ResolveCustomOpAddPass.h"
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
+#include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
#include "luci/Pass/RequantizePass.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
#include "luci/Pass/SparsifyTensorPass.h"
#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
#include "luci/Pass/SubstitutePackToReshapePass.h"
+#include "luci/Pass/SubstitutePadV2ToPadPass.h"
#include "luci/Pass/SubstituteSqueezeToReshapePass.h"
+#include "luci/Pass/SubstituteStridedSliceToReshapePass.h"
#include "luci/Pass/SubstituteTransposeToReshapePass.h"
#include "luci/Pass/TransformMinMaxToRelu6Pass.h"
+#include "luci/Pass/TransformMinReluToRelu6Pass.h"
// TODO add more passes
#include "luci/Pass/CircleShapeInferencePass.h"
#include <luci/IR/CircleNodes.h>
#include <logo/Phase.h>
+#include <pepper/csv2vec.h>
#include <memory>
#include <sstream>
return true;
}
+void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output)
+{
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ phase.emplace_back(
+ std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Restart);
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+}
+
} // namespace
namespace luci
{
logo::Phase phase;
+ // Conversion from NCHW to NHWC is done first to avoid interference with other optimizations.
+ if (_options->query(Options::Algorithm::ConvertNCHWToNHWC))
+ {
+ bool preserve_input =
+ _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_input_shape) != "true";
+ bool preserve_output =
+ _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true";
+
+ convert_nchw_to_nhwc(g, preserve_input, preserve_output);
+ }
+
/* TRANSFORM DECLARATION BEGIN */
phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
{
phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
}
+ if (_options->query(Options::Algorithm::FuseMeanWithMean))
+ {
+ phase.emplace_back(std::make_unique<FuseMeanWithMeanPass>());
+ }
+ if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax))
+ {
+ phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
+ }
if (_options->query(Options::Algorithm::FuseInstanceNorm))
{
phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
{
phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
}
+ if (_options->query(Options::Algorithm::FuseTransposeWithMean))
+ {
+ phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>());
+ }
if (_options->query(Options::Algorithm::FoldAddV2))
{
phase.emplace_back(std::make_unique<luci::FoldAddV2Pass>());
{
phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>());
}
+ if (_options->query(Options::Algorithm::RemoveFakeQuant))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveFakeQuantPass>());
+ }
+ if (_options->query(Options::Algorithm::RemoveQuantDequantSeq))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveQuantDequantSeqPass>());
+ }
if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape))
{
phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>());
{
phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
}
+ if (_options->query(Options::Algorithm::ReplaceSubWithAdd))
+ {
+ phase.emplace_back(std::make_unique<luci::ReplaceSubWithAddPass>());
+ }
if (_options->query(Options::Algorithm::SubstitutePackToReshape))
{
phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
}
+ if (_options->query(Options::Algorithm::SubstitutePadV2ToPad))
+ {
+ phase.emplace_back(std::make_unique<luci::SubstitutePadV2ToPadPass>());
+ }
if (_options->query(Options::Algorithm::SubstituteSqueezeToReshape))
{
phase.emplace_back(std::make_unique<luci::SubstituteSqueezeToReshapePass>());
}
+ if (_options->query(Options::Algorithm::SubstituteStridedSliceToReshape))
+ {
+ phase.emplace_back(std::make_unique<luci::SubstituteStridedSliceToReshapePass>());
+ }
if (_options->query(Options::Algorithm::SubstituteTransposeToReshape))
{
phase.emplace_back(std::make_unique<luci::SubstituteTransposeToReshapePass>());
{
phase.emplace_back(std::make_unique<luci::TransformMinMaxToRelu6Pass>());
}
- if (_options->query(Options::Algorithm::ConvertNCHWToNHWC))
+ if (_options->query(Options::Algorithm::TransformMinReluToRelu6Pass))
{
- bool preserve_input =
- _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_preserve_input_shape) == "true";
- bool preserve_output =
- _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_preserve_output_shape) == "true";
-
- phase.emplace_back(
- std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
+ phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>());
}
/* TRANSFORM DECLARATION END */
std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
// traversal order
- std::vector<int32_t> traversal_order = csv_to_vector<int32_t>(str_tarversal_order);
+ std::vector<int32_t> traversal_order = pepper::csv_to_vector<int32_t>(str_tarversal_order);
// format
std::vector<DimensionType> format;
std::istringstream is(str_format);
is.ignore();
}
// block size
- std::vector<int32_t> block_size = csv_to_vector<int32_t>(str_block_size);
+ std::vector<int32_t> block_size = pepper::csv_to_vector<int32_t>(str_block_size);
// block map
- std::vector<int32_t> block_map = csv_to_vector<int32_t>(str_block_map);
+ std::vector<int32_t> block_map = pepper::csv_to_vector<int32_t>(str_block_map);
luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,
block_map};
options->enable(Algorithms::RemoveUnnecessarySlice);
options->enable(Algorithms::RemoveUnnecessarySplit);
options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv);
+ options->enable(Algorithms::SubstituteStridedSliceToReshape);
options->enable(Algorithms::SubstituteTransposeToReshape);
options->enable(Algorithms::ConvertNCHWToNHWC);
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Log.h>
+#include <functional>
+
namespace
{
bool has_data_format(loco::Node *node) { return node->annot<DataFormatAnnotation>() != nullptr; }
+bool check_4d_transpose(loco::Node *node, const std::vector<int32_t> indices)
+{
+ assert(indices.size() == 4);
+
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ return false;
+
+ if (not trans->perm())
+ return false;
+
+ auto perm = dynamic_cast<luci::CircleConst *>(trans->perm());
+ // Only const perm is supported
+ if (not perm)
+ return false;
+
+ if (perm->dtype() != loco::DataType::S32)
+ return false;
+
+ if (perm->size<loco::DataType::S32>() != 4)
+ return false;
+
+ for (uint32_t i = 0; i < 4; i++)
+ {
+ if (perm->at<loco::DataType::S32>(i) != indices[i])
+ return false;
+ }
+
+ return true;
+}
+
luci::CircleTranspose *create_4d_transpose(luci::CircleNode *node,
const std::vector<int32_t> indices)
{
return trans;
}
+luci::CircleTranspose *create_Nd_transpose(luci::CircleNode *node,
+ const std::vector<int32_t> indices)
+{
+ auto name = node->name();
+ assert(name.length() > 0);
+
+ auto perm = node->graph()->nodes()->create<luci::CircleConst>();
+ perm->dtype(loco::DataType::S32);
+ perm->size<loco::DataType::S32>(indices.size());
+ perm->rank(1);
+ perm->dim(0) = indices.size();
+ for (uint32_t i = 0; i < indices.size(); i++)
+ perm->at<loco::DataType::S32>(i) = indices[i];
+ perm->shape_status(luci::ShapeStatus::VALID);
+
+ auto make_string = [](const std::vector<int32_t> &nums) {
+ std::string str;
+ for (auto num : nums)
+ {
+ if (str.length() > 0)
+ str += ".";
+ str += std::to_string(num);
+ }
+ return str;
+ };
+
+ auto str_indices = make_string(indices);
+
+ perm->name(name + "/Transpose_" + str_indices + "/perm");
+
+ auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
+ trans->perm(perm);
+ trans->name(name + "/Transpose_" + str_indices);
+ luci::add_origin(trans, luci::get_origin(node));
+
+ return trans;
+}
+
int32_t nchw_axis_to_nhwc(int32_t axis)
{
uint32_t pos_axis = axis >= 0 ? static_cast<uint32_t>(axis) : static_cast<uint32_t>(axis + 4);
return create_4d_transpose(node, {0, 2, 3, 1});
}
+bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
+{
+ assert(indices.size() == 4); // FIX_CALLER_UNLESS
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(node);
+ if (not reshape)
+ return false;
+
+ if (reshape->rank() != 4)
+ return false;
+
+ auto input = loco::must_cast<luci::CircleNode *>(reshape->tensor());
+ if (input->shape_status() != luci::ShapeStatus::VALID)
+ return false;
+
+ if (reshape->shape_status() != luci::ShapeStatus::VALID)
+ return false;
+
+ if (!(input->dim(0) == reshape->dim(indices[0])) ||
+ !(input->dim(1) == reshape->dim(indices[1])) ||
+ !(input->dim(2) == reshape->dim(indices[2])) || !(input->dim(3) == reshape->dim(indices[3])))
+ return false;
+
+ return true;
+}
+
+// Check if the Reshape converts NCHW -> NHWC
+bool is_pre_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 3, 1, 2}); }
+
+// Check if the Reshape converts NHWC -> NCHW
+bool is_post_reshape(loco::Node *node) { return check_4d_reshape(node, {0, 2, 3, 1}); }
+
+bool is_post_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 3, 1, 2}); }
+
+bool is_pre_transpose(loco::Node *node) { return check_4d_transpose(node, {0, 2, 3, 1}); }
+
uint32_t cal_offset(const loco::TensorShape &dimension, const uint32_t *indices)
{
return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() *
return nhwc_paddings;
}
+luci::CircleConst *create_NHWC_rindices(luci::CircleConst *rindices)
+{
+ assert(rindices != nullptr); // FIX_CALLER_UNLESS
+
+ if (rindices->dtype() != loco::DataType::S32)
+ return nullptr;
+
+ auto nhwc_rindices = luci::clone(rindices);
+ auto name = rindices->name();
+ assert(name.length() > 0); // FIX_CALLER_UNLESS
+ nhwc_rindices->name(name + "_NHWC");
+
+ auto size = nhwc_rindices->size<loco::DataType::S32>();
+ for (uint32_t i = 0; i < size; i++)
+ {
+ nhwc_rindices->at<loco::DataType::S32>(i) =
+ nchw_axis_to_nhwc(rindices->at<loco::DataType::S32>(i));
+ }
+
+ return nhwc_rindices;
+}
+
luci::CircleConst *create_NHWC_from_NCHW(luci::CircleConst *constant)
{
LOGGER(l);
return true;
}
+// NOTE Copied from is_NCHW(CirclePad)
+bool is_NCHW(const luci::CirclePadV2 *node)
+{
+ const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
+ // Non-const paddings is not supported
+ if (paddings == nullptr)
+ return false;
+
+ if (paddings->rank() != 2)
+ return false;
+
+ if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
+ return false;
+
+ // Only check the first two dimensions
+ for (uint32_t dim = 0; dim < 2; dim++)
+ {
+ for (uint32_t i = 0; i < 2; i++)
+ {
+ auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
+ if (data != 0)
+ return false;
+ }
+ }
+
+ return true;
+}
+
+// NOTE Following conditions can be extended later
+// NOTE Used for Maximum, Minimum as ReLU/ReLU6
+//
+// Find T with an NCHW pattern described below
+// - Input (non-constant) shape : [N, C, H, W]
+// - Input (constant) shape : [1] or []
+// - Output shape : [N, C, H, W]
+template <class T>
+bool is_NCHW_with_s_const(const T *node, luci::CircleNode *&pred_node,
+ luci::CircleConst *&comp_const)
+{
+ auto x = dynamic_cast<luci::CircleConst *>(node->x());
+ auto y = dynamic_cast<luci::CircleConst *>(node->y());
+
+ if (x != nullptr && y == nullptr)
+ {
+ pred_node = loco::must_cast<luci::CircleNode *>(node->y());
+ comp_const = x;
+ }
+ else if (x == nullptr && y != nullptr)
+ {
+ pred_node = loco::must_cast<luci::CircleNode *>(node->x());
+ comp_const = y;
+ }
+ else
+ {
+ // Ignore if T does not have a comp_const input.
+ return false;
+ }
+
+ if (pred_node->rank() != 4)
+ return false;
+
+ // Check if scalar
+ const auto const_rank = comp_const->rank();
+ if (const_rank == 0 || (const_rank == 1 && comp_const->dim(0).value() == 1))
+ return true;
+ return false;
+}
+
// NOTE Following conditions can be extended later
//
// Find MUL with an NCHW pattern described below
// - Input (non-constant) shape : [N, C, H, W]
-// - Input (constant) shape : [1, C, 1, 1]
+// - Input (constant) shape : [1, C, 1, 1] or a scalar (1)
// - Output shape : [N, C, H, W]
bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node,
luci::CircleConst *&multiplier)
return false;
const auto const_rank = multiplier->rank();
- if (const_rank != 4)
+ // Support Rank 4 or scalar (rank 0 or 1)
+ if (const_rank != 4 && const_rank != 0 && const_rank != 1)
return false;
- for (uint32_t i = 0; i < const_rank; i++)
+ if (const_rank == 4)
{
- if (i != 1 && multiplier->dim(i).value() != 1)
- return false;
+ for (uint32_t i = 0; i < const_rank; i++)
+ {
+ if (i != 1 && multiplier->dim(i).value() != 1)
+ return false;
+ }
}
- const auto const_cdim = multiplier->dim(1);
const auto input_cdim = pred_node->dim(1);
const auto output_cdim = node->dim(1);
- if (const_cdim == input_cdim && input_cdim == output_cdim)
+ if (const_rank == 4)
+ {
+ const auto const_cdim = multiplier->dim(1);
+ // Check Input, Output, Const have the same channel size
+ if (const_cdim == input_cdim && input_cdim == output_cdim)
+ return true;
+ else
+ return false;
+ }
+ if (input_cdim == output_cdim)
return true;
else
return false;
// We assume ADD with const input is NCHW if,
// Input shape: (N, C, H, W)
// Output shape: (N, C, H, W)
-// 1. Const shape is (1, C, 1, 1)
+// 1. Const shape is (1, C, 1, 1) or a scalar (1)
// 2. Input, Output, Const have the same C.
bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_node,
luci::CircleConst *&beta)
return false;
const auto const_rank = beta->rank();
- if (const_rank != 4)
+ // Support Rank 4 or scalar (rank 0 or 1)
+ if (const_rank != 4 && const_rank != 0 && const_rank != 1)
return false;
- // Check the shape is (1, C, 1, 1)
- for (uint32_t i = 0; i < const_rank; i++)
+ if (const_rank == 4)
{
- if (i == 1)
- continue;
+ // Check the shape is (1, C, 1, 1)
+ for (uint32_t i = 0; i < const_rank; i++)
+ {
+ if (i == 1)
+ continue;
- if (beta->dim(i).value() != 1)
+ if (beta->dim(i).value() != 1)
+ return false;
+ }
+ }
+
+ const auto input_cdim = pred_node->dim(1);
+ const auto output_cdim = node->dim(1);
+
+ if (const_rank == 4)
+ {
+ const auto const_cdim = beta->dim(1);
+ // Check Input, Output, Const have the same channel size
+ if (const_cdim == input_cdim && input_cdim == output_cdim)
+ return true;
+ else
return false;
}
+ if (input_cdim == output_cdim)
+ return true;
+ else
+ return false;
+}
+
+// We assume SUB with const input is NCHW if,
+// Input shape: (N, C, H, W)
+// Output shape: (N, C, H, W)
+// 1. Const shape is (1, C, 1, 1) or a scalar (1)
+// 2. Input, Output, Const have the same C.
+bool is_NCHW_with_const(const luci::CircleSub *node, const luci::CircleNode *pred_node,
+ const luci::CircleConst *subtract)
+{
+ assert(pred_node != nullptr);
+ assert(subtract != nullptr);
+
+ if (pred_node->rank() != 4)
+ return false;
+
+ const auto const_rank = subtract->rank();
+ // Support Rank 4 or scalar (rank 0 or 1)
+ if (const_rank != 4 && const_rank != 0 && const_rank != 1)
+ return false;
+
+ if (const_rank == 4)
+ {
+ // Check the shape is (1, C, 1, 1)
+ for (uint32_t i = 0; i < const_rank; i++)
+ {
+ if (i == 1)
+ continue;
+
+ if (subtract->dim(i).value() != 1)
+ return false;
+ }
+ }
- const auto const_cdim = beta->dim(1);
const auto input_cdim = pred_node->dim(1);
const auto output_cdim = node->dim(1);
- // Check Input, Output, Const have the same channel size
- if (const_cdim == input_cdim && input_cdim == output_cdim)
+ if (const_rank == 4)
+ {
+ const auto const_cdim = subtract->dim(1);
+ // Check Input, Output, Const have the same channel size
+ if (const_cdim == input_cdim && input_cdim == output_cdim)
+ return true;
+ else
+ return false;
+ }
+ if (input_cdim == output_cdim)
return true;
else
return false;
return true;
}
+template <class T> bool convert_unary_x(T *node)
+{
+ const auto pred_node = loco::must_cast<luci::CircleNode *>(node->x());
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->x(pre_trans);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+}
+
class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
{
// Default
auto pre_trans = create_pre_transpose(node);
pre_trans->a(pred_node);
- auto nhwc_const = create_NHWC_from_NCHW(beta);
- if (nhwc_const == nullptr)
- return false;
+ if (beta->rank() == 4)
+ {
+ auto nhwc_const = create_NHWC_from_NCHW(beta);
+ if (nhwc_const == nullptr)
+ return false;
+ node->y(nhwc_const);
+ }
node->x(pre_trans);
- node->y(nhwc_const);
}
else if (beta == nullptr)
{
return convert_unary_features<luci::CircleLeakyRelu>(node);
}
+ bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); }
+
+ bool visit(luci::CircleMaximum *node)
+ {
+ luci::CircleNode *pred_node = nullptr;
+ luci::CircleConst *comp_constant = nullptr;
+
+ if (is_NCHW_with_s_const<luci::CircleMaximum>(node, pred_node, comp_constant))
+ {
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->x(pre_trans);
+ }
+ else
+ {
+ // TODO support other cases
+ return false;
+ }
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+ return true;
+ }
+
+ bool visit(luci::CircleMean *node)
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(node->input());
+ if (input->rank() != 4)
+ return false;
+
+ auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
+ if (not rindices)
+ return false;
+
+ auto nhwc_rindices = create_NHWC_rindices(rindices);
+ if (not nhwc_rindices)
+ return false;
+
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(input);
+ node->input(pre_trans);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ node->reduction_indices(nhwc_rindices);
+
+ if (node->keep_dims())
+ {
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+ }
+
+ // node->keep_dims() == false
+ // 1D output never needs a transpose
+ if (node->rank() <= 1)
+ return true;
+
+ std::vector<bool> reduced_dims_nhwc(4, false);
+ uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
+
+ for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
+ {
+ reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
+ }
+
+ // if channel dimension has been reduced, we don't need a transpose
+ if (reduced_dims_nhwc[3])
+ return true;
+
+ // likewise, if both space dimensions are reduced, no transpose is needed
+ if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
+ return true;
+
+ std::vector<int32_t> post_trans_ind;
+ // case 1: only N is reduced
+ if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
+ post_trans_ind = {2, 0, 1};
+
+ // case 2: only H or W is reduced
+ if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
+ post_trans_ind = {0, 2, 1};
+
+ // case 3: N and either H or W are reduced
+ if (num_reduced_indices == 2)
+ post_trans_ind = {1, 0};
+
+ auto post_trans = create_Nd_transpose(node, post_trans_ind);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+ }
+
+ bool visit(luci::CircleMinimum *node)
+ {
+ luci::CircleNode *pred_node = nullptr;
+ luci::CircleConst *comp_constant = nullptr;
+
+ if (is_NCHW_with_s_const<luci::CircleMinimum>(node, pred_node, comp_constant))
+ {
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->x(pre_trans);
+ }
+ else
+ {
+ // TODO support other cases
+ return false;
+ }
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+ return true;
+ }
+
bool visit(luci::CircleMul *node)
{
LOGGER(l);
pre_trans->a(pred_node);
node->x(pre_trans);
- auto nhwc_const = create_NHWC_from_NCHW(multiplier);
- node->y(nhwc_const);
+ if (multiplier->rank() == 4)
+ {
+ auto nhwc_const = create_NHWC_from_NCHW(multiplier);
+ node->y(nhwc_const);
+ }
}
else if (multiplier == nullptr)
{
- // TODO : Implement this case.
- INFO(l) << "Not yet implemented. Both inputs of MUL are non-const." << std::endl;
- return false;
+ // Only support for input rank 4
+ auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
+ if (input_x->rank() != 4)
+ return false;
+ auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
+ if (input_y->rank() != 4)
+ return false;
+
+ auto pre_trans_x = create_pre_transpose(node);
+ pre_trans_x->a(input_x);
+ node->x(pre_trans_x);
+
+ auto pre_trans_y = create_pre_transpose(node);
+ pre_trans_y->a(input_y);
+ node->y(pre_trans_y);
}
else
{
return true;
}
- bool visit(luci::CircleNeg *node)
+ bool visit(luci::CircleNeg *node) { return convert_unary_x<luci::CircleNeg>(node); }
+
+ bool visit(luci::CirclePad *node)
{
- const auto pred_node = loco::must_cast<luci::CircleNode *>(node->x());
+ if (!is_NCHW(node))
+ return false;
+
+ const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
auto pre_trans = create_pre_transpose(node);
pre_trans->a(pred_node);
- node->x(pre_trans);
+ node->input(pre_trans);
+
+ auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
+ const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
+ node->paddings(nhwc_paddings);
// Do shape inference for this node again.
node->shape_status(luci::ShapeStatus::UNDEFINED);
return true;
}
- bool visit(luci::CirclePad *node)
+ bool visit(luci::CirclePadV2 *node)
{
if (!is_NCHW(node))
return false;
bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); }
bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); }
+
+ bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); }
+
+ bool visit(luci::CircleSquaredDifference *node)
+ {
+ // TODO support CircleConst input
+ if (dynamic_cast<luci::CircleConst *>(node->x()) != nullptr)
+ return false;
+ if (dynamic_cast<luci::CircleConst *>(node->y()) != nullptr)
+ return false;
+
+ auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
+ if (input_x->rank() != 4)
+ return false;
+ auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
+ if (input_y->rank() != 4)
+ return false;
+
+ auto pre_trans_x = create_pre_transpose(node);
+ pre_trans_x->a(input_x);
+ node->x(pre_trans_x);
+
+ auto pre_trans_y = create_pre_transpose(node);
+ pre_trans_y->a(input_y);
+ node->y(pre_trans_y);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+ return true;
+ }
+
+ bool visit(luci::CircleSub *node)
+ {
+ luci::CircleNode *pred_node = nullptr;
+ luci::CircleConst *subtract = nullptr;
+
+ auto const_x = dynamic_cast<luci::CircleConst *>(node->x());
+ auto const_y = dynamic_cast<luci::CircleConst *>(node->y());
+
+ if (const_x != nullptr && const_y == nullptr)
+ {
+ // case of subtract - pred_node
+ pred_node = loco::must_cast<luci::CircleNode *>(node->y());
+ subtract = const_x;
+
+ if (!is_NCHW_with_const(node, pred_node, subtract))
+ return false;
+
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+
+ if (subtract->rank() == 4)
+ {
+ auto nhwc_const = create_NHWC_from_NCHW(subtract);
+ if (nhwc_const == nullptr)
+ return false;
+ node->x(nhwc_const);
+ }
+ node->y(pre_trans);
+ }
+ else if (const_x == nullptr && const_y != nullptr)
+ {
+ // case of pred_node - subtract
+ pred_node = loco::must_cast<luci::CircleNode *>(node->x());
+ subtract = const_y;
+
+ if (!is_NCHW_with_const(node, pred_node, subtract))
+ return false;
+
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+
+ if (subtract->rank() == 4)
+ {
+ auto nhwc_const = create_NHWC_from_NCHW(subtract);
+ if (nhwc_const == nullptr)
+ return false;
+ node->y(nhwc_const);
+ }
+
+ node->x(pre_trans);
+ }
+ else if (const_x == nullptr && const_y == nullptr)
+ {
+ // Neither input is constant.
+ // In this case, we cannot distinguish NCHW from NHWC,
+ // so just insert Transpose Ops.
+ // Only rank-4 inputs are supported.
+ auto input_x = loco::must_cast<luci::CircleNode *>(node->x());
+ if (input_x->rank() != 4)
+ return false;
+ auto input_y = loco::must_cast<luci::CircleNode *>(node->y());
+ if (input_y->rank() != 4)
+ return false;
+
+ auto pre_trans_x = create_pre_transpose(node);
+ pre_trans_x->a(input_x);
+ node->x(pre_trans_x);
+
+ auto pre_trans_y = create_pre_transpose(node);
+ pre_trans_y->a(input_y);
+ node->y(pre_trans_y);
+ }
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+ return true;
+ }
};
} // namespace
LOGGER(l);
INFO(l) << "ConvertNCHWToNHWCPass Start" << std::endl;
+ // Annotate NHWC operators
+ // NHWC operators are detected by pattern matching
+ //
+ // Pattern
+ // pre-Transpose (or pre-Reshape) + [intermediate Ops] + post-Transpose (or post-Reshape)
+ //
+ // [intermediate Ops] are annotated as NHWC
+ //
+ // NOTE A single pre-Transpose/Reshape can have multiple post-Transpose/Reshape nodes.
+ // For example,
+ // pre-Transpose --- [intermediate Ops] --- post-Transpose
+ // |
+ // +--[intermediate Ops] --- post-Transpose
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ if (has_data_format(node))
+ continue;
+
+ if (is_pre_transpose(node) || is_pre_reshape(node))
+ {
+ // std::function is used so that the lambda can call itself recursively
+ std::function<void(loco::Node *)> set_data_format_to_succs;
+ set_data_format_to_succs = [&](loco::Node *n) {
+ for (auto succ : loco::succs(n))
+ {
+ // Exit condition
+ if (is_post_transpose(succ) || is_post_reshape(succ))
+ continue;
+
+ if (not has_data_format(succ))
+ {
+ set_data_format(succ, DataFormat::NHWC);
+ }
+
+ set_data_format_to_succs(succ);
+ }
+ };
+
+ set_data_format_to_succs(node);
+ }
+ }
+
// Annotate NCHW operators
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::CONCATENATION:
case luci::CircleOpcode::LEAKY_RELU:
+ case luci::CircleOpcode::LOGISTIC:
+ case luci::CircleOpcode::MAXIMUM:
+ case luci::CircleOpcode::MEAN:
+ case luci::CircleOpcode::MINIMUM:
case luci::CircleOpcode::MUL:
case luci::CircleOpcode::NEG:
case luci::CircleOpcode::PAD:
+ case luci::CircleOpcode::PADV2:
case luci::CircleOpcode::RELU:
case luci::CircleOpcode::RELU6:
+ case luci::CircleOpcode::RSQRT:
+ case luci::CircleOpcode::SQUARED_DIFFERENCE:
+ case luci::CircleOpcode::SUB:
if (!has_data_format(node))
{
set_data_format(node, DataFormat::NCHW);
ConvertNCHWToNHWC converter;
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
if (circle_node->rank() != 4)
- continue;
+ {
+ // TODO replace the check above with the input rank check, and remove the condition below
+ if (not dynamic_cast<luci::CircleMean *>(node))
+ continue;
+ }
if (circle_node->accept(&converter))
{
luci::CircleConst *beta = nullptr;
};
+class NHWCReluGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ relu = g.nodes()->create<luci::CircleRelu>();
+ pre_reshape = g.nodes()->create<luci::CircleReshape>();
+ post_reshape = g.nodes()->create<luci::CircleReshape>();
+ pre_shape = g.nodes()->create<luci::CircleConst>();
+ post_shape = g.nodes()->create<luci::CircleConst>();
+
+ pre_shape->dtype(loco::DataType::S32);
+ post_shape->dtype(loco::DataType::S32);
+
+ uint32_t channel_size = 16;
+ auto in = loco::must_cast<luci::CircleNode *>(input);
+ in->shape({1, channel_size, 4, 4});
+ pre_shape->shape({4});
+ post_shape->shape({4});
+
+ pre_shape->size<loco::DataType::S32>(4);
+ pre_shape->at<loco::DataType::S32>(0) = 1;
+ pre_shape->at<loco::DataType::S32>(1) = 4;
+ pre_shape->at<loco::DataType::S32>(2) = 4;
+ pre_shape->at<loco::DataType::S32>(3) = channel_size;
+
+ post_shape->size<loco::DataType::S32>(4);
+ post_shape->at<loco::DataType::S32>(0) = 1;
+ post_shape->at<loco::DataType::S32>(1) = channel_size;
+ post_shape->at<loco::DataType::S32>(2) = 4;
+ post_shape->at<loco::DataType::S32>(3) = 4;
+
+ pre_reshape->tensor(input);
+ pre_reshape->shape(pre_shape);
+
+ relu->features(pre_reshape);
+
+ post_reshape->tensor(relu);
+ post_reshape->shape(post_shape);
+
+ relu->name("Relu");
+ pre_reshape->name("pre-reshape");
+ post_reshape->name("post-reshape");
+
+ return post_reshape;
+ }
+
+public:
+ luci::CircleRelu *relu = nullptr;
+ luci::CircleReshape *pre_reshape = nullptr;
+ luci::CircleReshape *post_reshape = nullptr;
+ luci::CircleConst *pre_shape = nullptr;
+ luci::CircleConst *post_shape = nullptr;
+};
+
+class AddScalarGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ add = g.nodes()->create<luci::CircleAdd>();
+ beta = g.nodes()->create<luci::CircleConst>();
+
+ add->dtype(loco::DataType::FLOAT32);
+ beta->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ add->shape({1, channel_size, 4, 4});
+ beta->shape({1});
+
+ beta->size<loco::DataType::FLOAT32>(1);
+ beta->at<loco::DataType::FLOAT32>(0) = 3.14;
+
+ add->x(input);
+ add->y(beta);
+
+ add->name("add");
+ beta->name("beta");
+
+ return add;
+ }
+
+public:
+ luci::CircleAdd *add = nullptr;
+ luci::CircleConst *beta = nullptr;
+};
+
class ConcatenationGraph final : public SimpleGraph
{
protected:
luci::CircleLeakyRelu *leakyrelu = nullptr;
};
+class LogisticGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ logistic = g.nodes()->create<luci::CircleLogistic>();
+ logistic->x(input);
+ logistic->name("logistic");
+
+ return logistic;
+ }
+
+public:
+ luci::CircleLogistic *logistic = nullptr;
+};
+
+class MaximumGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ max = g.nodes()->create<luci::CircleMaximum>();
+ limit = g.nodes()->create<luci::CircleConst>();
+
+ max->dtype(loco::DataType::FLOAT32);
+ limit->dtype(loco::DataType::FLOAT32);
+
+ max->shape({1, 16, 4, 4});
+ limit->shape({});
+
+ limit->size<loco::DataType::FLOAT32>(1);
+ limit->at<loco::DataType::FLOAT32>(0) = 100;
+
+ max->x(input);
+ max->y(limit);
+
+ max->name("max");
+ limit->name("limit");
+
+ return max;
+ }
+
+public:
+ luci::CircleMaximum *max = nullptr;
+ luci::CircleConst *limit = nullptr;
+};
+
+class MeanGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ mean = g.nodes()->create<luci::CircleMean>();
+ rindices = g.nodes()->create<luci::CircleConst>();
+
+ mean->dtype(loco::DataType::FLOAT32);
+ rindices->dtype(loco::DataType::S32);
+
+ mean->shape(_shape);
+ rindices->shape({static_cast<uint32_t>(_axes.size())});
+
+ rindices->size<loco::DataType::S32>(_axes.size());
+ for (uint32_t i = 0; i < _axes.size(); ++i)
+ {
+ rindices->at<loco::DataType::S32>(i) = _axes[i];
+ }
+
+ mean->input(input);
+ mean->reduction_indices(rindices);
+ mean->keep_dims(_keep_dims);
+
+ mean->name("mean");
+ rindices->name("rindices");
+
+ return mean;
+ }
+
+public:
+ void keep_dims(bool val) { _keep_dims = val; }
+ void axes(std::vector<int32_t> val) { _axes = val; }
+ void shape(std::initializer_list<uint32_t> val) { _shape = val; }
+
+public:
+ luci::CircleMean *mean = nullptr;
+ luci::CircleConst *rindices = nullptr;
+
+private:
+ bool _keep_dims = true;
+ std::vector<int32_t> _axes = {2, 3};
+ std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
+};
+
+class MinimumGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ min = g.nodes()->create<luci::CircleMinimum>();
+ limit = g.nodes()->create<luci::CircleConst>();
+
+ min->dtype(loco::DataType::FLOAT32);
+ limit->dtype(loco::DataType::FLOAT32);
+
+ min->shape({1, 16, 4, 4});
+ limit->shape({});
+
+ limit->size<loco::DataType::FLOAT32>(1);
+ limit->at<loco::DataType::FLOAT32>(0) = 100;
+
+ min->x(input);
+ min->y(limit);
+
+ min->name("min");
+ limit->name("limit");
+
+ return min;
+ }
+
+public:
+ luci::CircleMinimum *min = nullptr;
+ luci::CircleConst *limit = nullptr;
+};
+
class MulGraph final : public SimpleGraph
{
protected:
luci::CircleConst *multiplier = nullptr;
};
+class MulScalarGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ mul = g.nodes()->create<luci::CircleMul>();
+ multiplier = g.nodes()->create<luci::CircleConst>();
+
+ mul->dtype(loco::DataType::FLOAT32);
+ multiplier->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ mul->shape({1, channel_size, 4, 4});
+ multiplier->shape({1});
+
+ multiplier->size<loco::DataType::FLOAT32>(1);
+ multiplier->at<loco::DataType::FLOAT32>(0) = 2;
+
+ mul->x(input);
+ mul->y(multiplier);
+
+ mul->name("mul");
+ multiplier->name("multiplier");
+
+ return mul;
+ }
+
+public:
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *multiplier = nullptr;
+};
+
+class MulBothNormGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ mul = g.nodes()->create<luci::CircleMul>();
+
+ mul->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ mul->shape({1, channel_size, 4, 4});
+
+ mul->x(input);
+ mul->y(input);
+
+ mul->name("mul");
+
+ return mul;
+ }
+
+public:
+ luci::CircleMul *mul = nullptr;
+};
+
class NegGraph final : public SimpleGraph
{
protected:
luci::CircleConst *paddings = nullptr;
};
+class PadV2Graph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ pad = g.nodes()->create<luci::CirclePadV2>();
+ paddings = g.nodes()->create<luci::CircleConst>();
+ const_value = g.nodes()->create<luci::CircleConst>();
+
+ pad->dtype(loco::DataType::FLOAT32);
+ paddings->dtype(loco::DataType::S32);
+ const_value->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ pad->shape({1, channel_size, 4, 4});
+ paddings->shape({4, 2});
+ const_value->shape({1});
+
+ // paddings data (NCHW)
+ // [[0,0], [0,0], [1,1], [2,2]]
+ paddings->size<loco::DataType::S32>(8);
+ for (uint32_t dim = 0; dim < 4; dim++)
+ {
+ for (uint32_t i = 0; i < 2; i++)
+ {
+ int32_t data = 0;
+
+ if (dim == 2)
+ data = 1;
+ else if (dim == 3)
+ data = 2;
+
+ paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
+ }
+ }
+
+ const_value->size<loco::DataType::FLOAT32>(1);
+ const_value->at<loco::DataType::FLOAT32>(0) = -3.4;
+
+ pad->input(input);
+ pad->paddings(paddings);
+ pad->constant_values(const_value);
+
+ pad->name("padV2");
+ paddings->name("paddings");
+ const_value->name("constant_values");
+
+ return pad;
+ }
+
+public:
+ luci::CirclePadV2 *pad = nullptr;
+ luci::CircleConst *paddings = nullptr;
+ luci::CircleConst *const_value = nullptr;
+};
+
class ReluGraph final : public SimpleGraph
{
protected:
luci::CircleRelu6 *relu6 = nullptr;
};
+class RsqrtGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ rsqrt = g.nodes()->create<luci::CircleRsqrt>();
+ rsqrt->x(input);
+ rsqrt->name("rsqrt");
+
+ return rsqrt;
+ }
+
+public:
+ luci::CircleRsqrt *rsqrt = nullptr;
+};
+
+class SquaredDifferenceGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ sqdiff = g.nodes()->create<luci::CircleSquaredDifference>();
+ sqdiff->x(input);
+ sqdiff->y(input);
+ sqdiff->name("sqdiff");
+
+ return sqdiff;
+ }
+
+public:
+ luci::CircleSquaredDifference *sqdiff = nullptr;
+};
+
+class SubGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ sub = g.nodes()->create<luci::CircleSub>();
+ beta = g.nodes()->create<luci::CircleConst>();
+
+ sub->dtype(loco::DataType::FLOAT32);
+ beta->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ sub->shape({1, channel_size, 4, 4});
+ beta->shape({1, channel_size, 1, 1});
+
+ beta->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ beta->at<loco::DataType::FLOAT32>(i) = i;
+ }
+
+ sub->x(input);
+ sub->y(beta);
+
+ sub->name("sub");
+ beta->name("beta");
+
+ return sub;
+ }
+
+public:
+ luci::CircleSub *sub = nullptr;
+ luci::CircleConst *beta = nullptr;
+};
+
+class SubScalarGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ sub = g.nodes()->create<luci::CircleSub>();
+ beta = g.nodes()->create<luci::CircleConst>();
+
+ sub->dtype(loco::DataType::FLOAT32);
+ beta->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ sub->shape({1, channel_size, 4, 4});
+ beta->shape({1});
+
+ beta->size<loco::DataType::FLOAT32>(1);
+ beta->at<loco::DataType::FLOAT32>(0) = 5;
+
+ sub->x(beta);
+ sub->y(input);
+
+ sub->name("sub");
+ beta->name("beta");
+
+ return sub;
+ }
+
+public:
+ luci::CircleSub *sub = nullptr;
+ luci::CircleConst *beta = nullptr;
+};
+
void check_pre_trans(loco::Node *node)
{
auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node);
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, NHWC_Relu)
+{
+ // Relu is already NHWC, so it should not be converted
+ // i.e., the graph is not changed
+ NHWCReluGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ EXPECT_EQ(g.pre_reshape, g.relu->features());
+
+ auto relu_succs = loco::succs(g.relu);
+ EXPECT_EQ(1, relu_succs.size());
+ EXPECT_EQ(g.post_reshape, *relu_succs.begin());
+}
+
+TEST(ConvertNCHWToNHWC, AddScalar)
+{
+ AddScalarGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.add->x());
+
+ auto add_succs = loco::succs(g.add);
+ EXPECT_EQ(1, add_succs.size());
+ check_post_trans(*add_succs.begin());
+
+ auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
+ EXPECT_NE(nullptr, new_beta);
+ EXPECT_EQ(1, new_beta->rank());
+ EXPECT_EQ(1, new_beta->dim(0).value());
+
+ check_pre_trans(g.output->from());
+}
+
TEST(ConvertNCHWToNHWC, Concatenation)
{
ConcatenationGraph g;
EXPECT_EQ(16, g.leakyrelu->dim(3).value());
}
+TEST(ConvertNCHWToNHWC, Logistic)
+{
+ LogisticGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.logistic->x());
+
+ auto logistic_succs = loco::succs(g.logistic);
+ EXPECT_EQ(1, logistic_succs.size());
+ check_post_trans(*logistic_succs.begin());
+
+ // Check logistic shape
+ EXPECT_EQ(1, g.logistic->dim(0).value());
+ EXPECT_EQ(4, g.logistic->dim(1).value());
+ EXPECT_EQ(4, g.logistic->dim(2).value());
+ EXPECT_EQ(16, g.logistic->dim(3).value());
+}
+
+TEST(ConvertNCHWToNHWC, Maximum)
+{
+ MaximumGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.max->x());
+
+ auto max_succs = loco::succs(g.max);
+ EXPECT_EQ(1, max_succs.size());
+ check_post_trans(*max_succs.begin());
+
+ check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, Mean)
+{
+ MeanGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ check_pre_trans(g.mean->input());
+
+ auto mean_succs = loco::succs(g.mean);
+ EXPECT_EQ(1, mean_succs.size());
+ check_post_trans(*mean_succs.begin());
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(2, new_rindices->dim(0).value());
+ EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
+ EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
+ EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
+}
+
+TEST(ConvertNCHWToNHWC, Mean_keep_dims_false)
+{
+ struct TC
+ {
+ std::vector<int32_t> nchw_ind;
+ std::vector<int32_t> nhwc_ind;
+ std::initializer_list<uint32_t> shape;
+ bool needs_transpose = false;
+ };
+
+ uint32_t n = 1;
+ uint32_t c = 16;
+ uint32_t h = 4;
+ uint32_t w = 4;
+
+ std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
+ {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
+ {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
+ {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
+ {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
+ {{0, 1, 2}, {0, 3, 1}, {w}, false}};
+
+ for (auto &tc : test_cases)
+ {
+ MeanGraph g;
+ g.keep_dims(false);
+ g.axes(tc.nchw_ind);
+ g.shape(tc.shape);
+ g.init();
+
+ run_phase(&g.g, false, true);
+
+ check_pre_trans(g.mean->input());
+
+ auto mean_succs = loco::succs(g.mean);
+ EXPECT_EQ(1, mean_succs.size());
+ if (tc.needs_transpose)
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*mean_succs.begin()));
+ }
+ else
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*mean_succs.begin()));
+ }
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.mean->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
+ for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
+ {
+ EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
+ }
+ }
+}
+
+TEST(ConvertNCHWToNHWC, ConvertNCHWToNHWC_Mean_keep_dims_false_NEG)
+{
+ loco::Graph g;
+ auto input = g.nodes()->create<luci::CircleInput>();
+ auto output = g.nodes()->create<luci::CircleOutput>();
+ input->name("input");
+ output->name("output");
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ graph_input->dtype(loco::DataType::FLOAT32);
+ input->dtype(loco::DataType::FLOAT32);
+ output->dtype(loco::DataType::FLOAT32);
+ graph_output->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ graph_input->shape({channel_size, 4, 4});
+ input->shape({channel_size, 4, 4});
+ output->shape({channel_size});
+ graph_output->shape({channel_size});
+
+ auto mean = g.nodes()->create<luci::CircleMean>();
+ auto rindices = g.nodes()->create<luci::CircleConst>();
+
+ mean->dtype(loco::DataType::FLOAT32);
+ rindices->dtype(loco::DataType::S32);
+
+ mean->shape({channel_size});
+ rindices->shape({2});
+
+ rindices->size<loco::DataType::S32>(2);
+ rindices->at<loco::DataType::S32>(0) = 1;
+ rindices->at<loco::DataType::S32>(1) = 2;
+
+ mean->input(input);
+ mean->reduction_indices(rindices);
+ mean->keep_dims(false);
+
+ mean->name("mean");
+ rindices->name("rindices");
+
+ output->from(mean);
+
+ run_phase(&g, true, true);
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(2, new_rindices->dim(0).value());
+ EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
+ EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
+ EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
+}
+
+TEST(ConvertNCHWToNHWC, Minimum)
+{
+ MinimumGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.min->x());
+
+ auto min_succs = loco::succs(g.min);
+ EXPECT_EQ(1, min_succs.size());
+ check_post_trans(*min_succs.begin());
+
+ check_pre_trans(g.output->from());
+}
+
TEST(ConvertNCHWToNHWC, Mul)
{
MulGraph g;
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, MulScalar)
+{
+ MulScalarGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.mul->x());
+
+ auto mul_succs = loco::succs(g.mul);
+ EXPECT_EQ(1, mul_succs.size());
+ check_post_trans(*mul_succs.begin());
+
+ auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
+ EXPECT_NE(nullptr, new_multiplier);
+ EXPECT_EQ(1, new_multiplier->rank());
+ EXPECT_EQ(1, new_multiplier->dim(0).value());
+
+ check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, MulBothNorm)
+{
+ MulBothNormGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.mul->x());
+ check_pre_trans(g.mul->y());
+
+ auto mul_succs = loco::succs(g.mul);
+ EXPECT_EQ(1, mul_succs.size());
+ check_post_trans(*mul_succs.begin());
+
+ check_pre_trans(g.output->from());
+}
+
TEST(ConvertNCHWToNHWC, Neg)
{
NegGraph g;
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, PadV2)
+{
+ PadV2Graph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ check_pre_trans(g.pad->input());
+
+ auto pad_succs = loco::succs(g.pad);
+ EXPECT_EQ(1, pad_succs.size());
+ check_post_trans(*pad_succs.begin());
+
+ auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
+ EXPECT_NE(nullptr, new_paddings);
+ EXPECT_EQ(2, new_paddings->rank());
+ EXPECT_EQ(4, new_paddings->dim(0).value());
+ EXPECT_EQ(2, new_paddings->dim(1).value());
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
+ EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
+ EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
+ EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
+ EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
+}
+
TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG)
{
AddGraph g;
EXPECT_EQ(4, g.relu6->dim(2).value());
EXPECT_EQ(16, g.relu6->dim(3).value());
}
+
+TEST(ConvertNCHWToNHWC, Rsqrt)
+{
+ RsqrtGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.rsqrt->x());
+
+ auto rsqrt_succs = loco::succs(g.rsqrt);
+ EXPECT_EQ(1, rsqrt_succs.size());
+ check_post_trans(*rsqrt_succs.begin());
+
+ // Check rsqrt shape
+ EXPECT_EQ(1, g.rsqrt->dim(0).value());
+ EXPECT_EQ(4, g.rsqrt->dim(1).value());
+ EXPECT_EQ(4, g.rsqrt->dim(2).value());
+ EXPECT_EQ(16, g.rsqrt->dim(3).value());
+}
+
+TEST(ConvertNCHWToNHWC, SquaredDifference)
+{
+ SquaredDifferenceGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.sqdiff->x());
+ check_pre_trans(g.sqdiff->y());
+
+ auto sqdiff_succs = loco::succs(g.sqdiff);
+ EXPECT_EQ(1, sqdiff_succs.size());
+ check_post_trans(*sqdiff_succs.begin());
+}
+
+TEST(ConvertNCHWToNHWC, Sub)
+{
+ SubGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.sub->x());
+
+ auto sub_succs = loco::succs(g.sub);
+ EXPECT_EQ(1, sub_succs.size());
+ check_post_trans(*sub_succs.begin());
+
+ uint32_t channel_size = 16;
+ auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->y());
+ EXPECT_NE(nullptr, new_beta);
+ EXPECT_EQ(4, new_beta->rank());
+ EXPECT_EQ(1, new_beta->dim(0).value());
+ EXPECT_EQ(1, new_beta->dim(1).value());
+ EXPECT_EQ(1, new_beta->dim(2).value());
+ EXPECT_EQ(channel_size, new_beta->dim(3).value());
+
+ check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, SubScalar)
+{
+ SubScalarGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.sub->y());
+
+ auto sub_succs = loco::succs(g.sub);
+ EXPECT_EQ(1, sub_succs.size());
+ check_post_trans(*sub_succs.begin());
+
+ auto new_beta = dynamic_cast<luci::CircleConst *>(g.sub->x());
+ EXPECT_NE(nullptr, new_beta);
+ EXPECT_EQ(1, new_beta->rank());
+
+ check_pre_trans(g.output->from());
+}
return false;
if (add->dtype() != loco::DataType::FLOAT32)
return false;
- // TODO support more Activations
- if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
- add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
- return false;
// get weight of dwconv
auto filter = dynamic_cast<luci::CircleConst *>(dwconv->filter());
#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/CircleNodeClone.h>
#include <cassert>
#include <set>
return node->dim(axis).value() == depth;
}
-/// @return true if node shape consists of ones, except the one before the last dim: 1,...1,depth,1
-bool is_quasi_1D_with_dummy_dim(luci::CircleConst *node, uint32_t depth)
-{
- auto rank = node->rank();
- // minimal accepted shape is [1 x depth x 1]
- if (rank < 3)
- return false;
- const auto depth_axis = rank - 2;
- for (uint32_t axis = 0; axis < rank; ++axis)
- {
- if (axis != depth_axis && node->dim(axis).value() != 1)
- return false;
- }
- return node->dim(depth_axis).value() == depth;
-}
-
-bool is_instance_mean_v0(luci::CircleMean *mean)
+bool is_instance_mean_v1(luci::CircleMean *mean)
{
//
// CHECK 1) input is rank 4
return mean->keep_dims();
}
-bool is_instance_mean_v1(luci::CircleMean *mean)
-{
- //
- // CHECK 1) input is rank 5 (NHWCX)
- //
- auto input = loco::must_cast<luci::CircleNode *>(mean->input());
- if (input->shape_status() != luci::ShapeStatus::VALID)
- return false;
- if (input->rank() != 5)
- return false;
-
- //
- // CHECK 2) 'reduction indices' is CircleConst of value [1,2,4], that is HWX of NHWCX input shape
- //
- // TODO Support equivalent case, like [-3,-2]
- // TODO Support non-Const case?
- // TODO What if input is NCHW format in Circle?
- auto red_indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
- if (not red_indices)
- return false;
- if (red_indices->rank() != 1)
- return false;
- std::set<int32_t> red_indices_set;
-
- // TODO Currently only support S32, support other types
- if (red_indices->dtype() != loco::DataType::S32)
- return false;
- for (uint32_t i = 0; i < red_indices->dim(0).value(); ++i)
- red_indices_set.insert(red_indices->at<loco::DataType::S32>(i));
-
- if (red_indices_set.size() != 3)
- return false;
- if (red_indices_set.find(1) == red_indices_set.end())
- return false;
- if (red_indices_set.find(2) == red_indices_set.end())
- return false;
- if (red_indices_set.find(4) == red_indices_set.end())
- return false;
-
- //
- // CHECK 3) keep_dims == true (?)
- //
- // We only have case of 'keep_dims == true' so far, but it might be okay with 'keep_dims == false'
- // TODO Check this fact, and if true, return true regardless of keep_dims
- return mean->keep_dims();
-}
-
/// @return true When node has the shape of 1D channel_size
bool is_1D_float32_const(const luci::CircleConst *node, uint32_t channel_size)
{
*
* TODO support other semantically same patterns for instance norm
*
+ * Version_1
* [In]
* |
* V
* V
* [Out]
*-------------------------------------------------------------------
- * [In]
- * |
- * V
- * ifm
- * |
- * V
- * +---------reshape_of_ifm ----+ (reduction indicies)
- * | | | |
- * | | V V
- * | | mean_of_reshape -------------+
- * | V | |
- * | sqdiff <--+ (reduction indicies) |
- * | | | |
- * | V | |
- * | mean_as_variance <---+ const_as_epsilon |
- * | | | |
- * | V | |
- * | add_as_variance <--------+ |
- * | | |
- * | V |
- * | rsqrt const_as_gamma |
- * | | | |
- * | V | |
- * | mul_gamma <--+ |
- * | | | |
- * V V V |
- * mul_as_scaled_reshape mul_as_scaled_mean <-----------+
- * | |
- * | const_as_beta |
- * | | V
- * | +------> sub
- * V |
- * add_as_terminal <----------+
- * |
- * V
- * reshape_as_terminal
- * |
- * V
- * [Out]
- *-------------------------------------------------------------------
+ * Version_2
* [In]
* |
* V
* |
* V
* [Out]
+ *-------------------------------------------------------------------
+ * Version_3
+ * [In]
+ * |
+ * V
+ * +----+-------------- ifm ---+
+ * | | (reduction | | (reduction
+ * | | indicies) | | indicies)
+ * | | | | | |
+ * | V V | V V
+ * | mean_of_ifm | mean_of_ifm_2
+ * | | | |
+ * V | V |
+ * sub <----+ sub_2 <---+
+ * | |
+ * | V
+ * | square
+ * | | (reduction indicies)
+ * | | |
+ * | V |
+ * | mean_as_variance <---+
+ * | |
+ * | V
+ * | sqrt const_as_epsilon
+ * | | |
+ * | V |
+ * | add_as_variance <---+
+ * | |
+ * V |
+ * div <------------------+ const_as_gamma
+ * | |
+ * V |
+ * mul_gamma <----------------------+
+ * | const_as_beta
+ * V |
+ * add_as_terminal <--------+
+ * |
+ * V
+ * [Out]
+ *-------------------------------------------------------------------
+ * Version_4
+ * - mul_gamma and add_as_terminal are removed for const_as_gamma = 1.0
+ * and const_as_beta = 0.0
+ * [In]
+ * |
+ * V
+ * +----+-------------- ifm ---+
+ * | | (reduction | | (reduction
+ * | | indicies) | | indicies)
+ * | | | | | |
+ * | V V | V V
+ * | mean_of_ifm | mean_of_ifm_2
+ * | | | |
+ * V | V |
+ * sub <----+ sub_2 <---+
+ * | |
+ * | V
+ * | square
+ * | | (reduction indicies)
+ * | | |
+ * | V |
+ * | mean_as_variance <---+
+ * | |
+ * | V
+ * | sqrt const_as_epsilon
+ * | | |
+ * | V |
+ * | add_as_variance <---+
+ * | |
+ * V |
+ * div <------------------+
+ * |
+ * V
+ * [Out]
+ *-------------------------------------------------------------------
+ * Version_5
+ * [In]
+ * |
+ * V
+ * +----------- ifm -----+ (reduction indicies)
+ * | | | |
+ * | | V V
+ * | | mean_of_ifm ----------------+
+ * | V | |
+ * | sqdiff <--+ (reduction indicies) |
+ * | | | |
+ * | V | |
+ * | mean_as_variance <---+ const_as_epsilon |
+ * | | | |
+ * | V | |
+ * | add_as_variance <--------+ |
+ * | | |
+ * | V |
+ * | rsqrt |
+ * | | |
+ * | +--+--+ |
+ * | | | |
+ * V V V |
+ * mul_as_scaled_ifm mul_as_scaled_mean <-------------+
+ * | |
+ * | const_as_beta |
+ * | | V
+ * | +------> sub
+ * V |
+ * add_as_terminal <----------+
+ * |
+ * V
+ * [Out]
*/
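+// NOTE (summary) All PatternVersions above are algebraic rewritings of the same
+// per-channel instance normalization. Informally, the fused CircleInstanceNorm
+// is expected to compute
+//
+//   out = gamma * (ifm - mean(ifm)) / sqrt(var(ifm) + epsilon) + beta
+//
+// where mean/var are reduced over the spatial axes of the rank-4 ifm.
+// Version_4 and Version_5 have no mul_gamma, so gamma is synthesized as 1.0
+// (and beta as 0.0 for Version_4).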
class InstanceNormPattern final
{
public:
enum PatternVersion
{
- Version_0,
+ Version_Unknown,
Version_1,
Version_2,
+ Version_3,
+ Version_4,
+ Version_5,
};
InstanceNormPattern(luci::CircleAdd *candidate, PatternVersion pv)
_pv = pv;
}
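+ // Version_4 ends at a Div (there is no trailing Add), so it needs a
+ // constructor that takes the candidate CircleDiv instead of a CircleAdd.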
+ InstanceNormPattern(luci::CircleDiv *candidate, PatternVersion pv)
+ {
+ assert(candidate);
+ div = candidate;
+ _pv = pv;
+ }
+
+private:
+ template <enum PatternVersion> bool match();
+
public:
bool matched();
bool matched() const { return _matched; }
loco::Node *ifm = nullptr;
luci::CircleReshape *reshape_of_ifm = nullptr;
luci::CircleMean *mean_of_ifm = nullptr;
+ luci::CircleMean *mean_of_ifm_2 = nullptr;
luci::CircleMean *mean_of_reshape = nullptr;
luci::CircleSquaredDifference *sqdiff = nullptr;
+ luci::CircleSquare *square = nullptr;
luci::CircleMean *mean_as_variance = nullptr;
luci::CircleConst *const_as_epsilon = nullptr;
luci::CircleAdd *add_as_variance = nullptr;
luci::CircleMul *mul_as_scaled_reshape = nullptr;
luci::CircleConst *const_as_beta = nullptr;
luci::CircleSub *sub = nullptr;
+ luci::CircleSub *sub_2 = nullptr;
luci::CircleAdd *add_as_terminal = nullptr;
luci::CirclePow *pow = nullptr;
+ luci::CircleSqrt *sqrt = nullptr;
luci::CircleDiv *div = nullptr;
private:
PatternVersion _pv;
};
-bool InstanceNormPattern::matched()
-{
- if (_matched)
- return true;
-
#define CHECK_OR_FALSE(condition) \
if (not(condition)) \
return false;
- // Check order is DFS
+template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
+{
+ CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
- // Version 2 is quite different from Version 0 and 1.
- // So it is handled in the separate if statement
- if (_pv == PatternVersion::Version_2)
- {
- CHECK_OR_FALSE(
- luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
- CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
+ auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
+ CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
+ CHECK_OR_FALSE(ifm_circle->rank() == 4);
+ CHECK_OR_FALSE(ifm_circle->dim(3).known());
+ uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
- sub = dynamic_cast<luci::CircleSub *>(div->x());
- CHECK_OR_FALSE(sub);
+ CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
- ifm = sub->x();
- CHECK_OR_FALSE(ifm);
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
- luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
- CHECK_OR_FALSE(ifm_node->rank() == 4);
- CHECK_OR_FALSE(ifm_node->dim(3).known());
- uint32_t ifm_channel_depth = ifm_node->dim(3).value();
+ add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
+ CHECK_OR_FALSE(add_as_variance);
- mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
- CHECK_OR_FALSE(mean_of_ifm);
+ CHECK_OR_FALSE(
+ luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
- CHECK_OR_FALSE(ifm == mean_of_ifm->input());
+ CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
- pow = dynamic_cast<luci::CirclePow *>(div->y());
- CHECK_OR_FALSE(pow);
+ CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
- add_as_variance = dynamic_cast<luci::CircleAdd *>(pow->x());
- CHECK_OR_FALSE(add_as_variance);
+ sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
+ CHECK_OR_FALSE(sqdiff);
- luci::CircleConst *zero_point_five = dynamic_cast<luci::CircleConst *>(pow->y());
- CHECK_OR_FALSE(zero_point_five);
- CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
- // TODO Support regarding broadcast
- CHECK_OR_FALSE(zero_point_five->size<loco::DataType::FLOAT32>() == 1);
- CHECK_OR_FALSE(zero_point_five->at<loco::DataType::FLOAT32>(0) == 0.5);
+ loco::Node *ifm_should_be = nullptr;
+ CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+ CHECK_OR_FALSE(is_instance_mean_v1(mean_of_ifm));
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
- CHECK_OR_FALSE(
- luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
- CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
- // TODO Support regarding broadcast
- CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
+ const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
+ CHECK_OR_FALSE(const_as_beta);
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
- CHECK_OR_FALSE(is_instance_mean_v0(mean_as_variance));
+ luci::CircleMul *mul_gamma_should_be = nullptr;
+ luci::CircleMean *mean_of_ifm_should_be = nullptr;
- sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
- CHECK_OR_FALSE(sqdiff);
+ mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
+ CHECK_OR_FALSE(mul_as_scaled_mean);
+ CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
+ .with_commutative_args_of(mul_as_scaled_mean));
+ CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
+ CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
- loco::Node *ifm_should_be = nullptr;
- luci::CircleMean *mean_of_ifm_should_be = nullptr;
- CHECK_OR_FALSE(
- luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
- CHECK_OR_FALSE(ifm == ifm_should_be);
- CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
+ _matched = true;
+ return true;
+}
- // Check for channel size
- CHECK_OR_FALSE(is_1D_float32_const(const_as_gamma, ifm_channel_depth));
- CHECK_OR_FALSE(is_1D_float32_const(const_as_beta, ifm_channel_depth));
+template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_2>()
+{
+ CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
- _matched = true;
- return true;
- }
+ sub = dynamic_cast<luci::CircleSub *>(div->x());
+ CHECK_OR_FALSE(sub);
- if (_pv == PatternVersion::Version_0)
- {
- CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
- CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
- }
- if (_pv == PatternVersion::Version_1)
- {
- CHECK_OR_FALSE(
- luci::fill(&mul_as_scaled_reshape, &sub).with_commutative_args_of(add_as_terminal));
- CHECK_OR_FALSE(
- luci::fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_reshape));
- ifm = reshape_of_ifm->tensor();
- }
+ ifm = sub->x();
+ CHECK_OR_FALSE(ifm);
+
+ luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
+ CHECK_OR_FALSE(ifm_node->rank() == 4);
+ CHECK_OR_FALSE(ifm_node->dim(3).known());
+ uint32_t ifm_channel_depth = ifm_node->dim(3).value();
+
+ mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
+ CHECK_OR_FALSE(mean_of_ifm);
+
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
+
+ pow = dynamic_cast<luci::CirclePow *>(div->y());
+ CHECK_OR_FALSE(pow);
+
+ add_as_variance = dynamic_cast<luci::CircleAdd *>(pow->x());
+ CHECK_OR_FALSE(add_as_variance);
+
+ luci::CircleConst *zero_point_five = dynamic_cast<luci::CircleConst *>(pow->y());
+ CHECK_OR_FALSE(zero_point_five);
+ CHECK_OR_FALSE(zero_point_five->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(zero_point_five->size<loco::DataType::FLOAT32>() == 1);
+ CHECK_OR_FALSE(zero_point_five->at<loco::DataType::FLOAT32>(0) == 0.5);
+
+ CHECK_OR_FALSE(
+ luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+ CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
+
+ CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
+
+ sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
+ CHECK_OR_FALSE(sqdiff);
+
+ loco::Node *ifm_should_be = nullptr;
+ luci::CircleMean *mean_of_ifm_should_be = nullptr;
+ CHECK_OR_FALSE(
+ luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+ CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
+
+ // Check for channel size
+ CHECK_OR_FALSE(is_1D_float32_const(const_as_gamma, ifm_channel_depth));
+ CHECK_OR_FALSE(is_1D_float32_const(const_as_beta, ifm_channel_depth));
+
+ _matched = true;
+ return true;
+}
+
+template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_3>()
+{
+ CHECK_OR_FALSE(luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
+ CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
+
+ // check left sub
+ ifm = sub->x();
+ CHECK_OR_FALSE(ifm);
+
+ luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
+ CHECK_OR_FALSE(ifm_node->rank() == 4);
+ CHECK_OR_FALSE(ifm_node->dim(3).known());
+
+ mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
+ CHECK_OR_FALSE(mean_of_ifm);
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
+
+ // continue search from add_as_variance
+ CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+ CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
+
+ mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
+ CHECK_OR_FALSE(mean_as_variance);
+
+ square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
+ CHECK_OR_FALSE(square);
+
+ sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
+ CHECK_OR_FALSE(sub_2);
+ CHECK_OR_FALSE(ifm == sub_2->x());
+
+ mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
+ CHECK_OR_FALSE(mean_of_ifm_2);
+ CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
+
+ loco::Node *ifm_should_be = nullptr;
+ luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
+ CHECK_OR_FALSE(
+ luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+ CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
+
+ _matched = true;
+ return true;
+}
+
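+// Build a one-element FLOAT32 constant of rank 1 holding `value`.
+// Used by the Version_4/Version_5 matchers to synthesize gamma (1.0) and,
+// for Version_4, beta (0.0) when the original graph does not provide them.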
+luci::CircleConst *make_const_one(loco::Graph *graph, float value)
+{
+ auto const_one = graph->nodes()->create<luci::CircleConst>();
+ const_one->dtype(loco::DataType::FLOAT32);
+ const_one->rank(1);
+ const_one->size<loco::DataType::FLOAT32>(1);
+ const_one->at<loco::DataType::FLOAT32>(0) = value;
+ return const_one;
+}
+
+template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_4>()
+{
+ CHECK_OR_FALSE(div);
+ CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
+
+ // check left sub
+ ifm = sub->x();
+ CHECK_OR_FALSE(ifm);
+
+ luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
+ CHECK_OR_FALSE(ifm_node->rank() == 4);
+ CHECK_OR_FALSE(ifm_node->dim(3).known());
+
+ mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
+ CHECK_OR_FALSE(mean_of_ifm);
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
+
+ // continue search from add_as_variance
+ CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+ CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
+
+ mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
+ CHECK_OR_FALSE(mean_as_variance);
+
+ square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
+ CHECK_OR_FALSE(square);
+
+ sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
+ CHECK_OR_FALSE(sub_2);
+ CHECK_OR_FALSE(ifm == sub_2->x());
+
+ mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
+ CHECK_OR_FALSE(mean_of_ifm_2);
+ CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
+
+ loco::Node *ifm_should_be = nullptr;
+ luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
+ CHECK_OR_FALSE(
+ luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+ CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
+
+ assert(const_as_gamma == nullptr);
+ assert(const_as_beta == nullptr);
+ assert(mul_gamma == nullptr);
+ assert(add_as_terminal == nullptr);
+
+ // create 1.0 gamma and 0.0 beta
+ auto graph = div->graph();
+ const_as_gamma = make_const_one(graph, 1.0f);
+ const_as_beta = make_const_one(graph, 0.0f);
+ const_as_gamma->name(div->name() + "/gamma");
+ const_as_beta->name(div->name() + "/beta");
+
+ _matched = true;
+ return true;
+}
+
+template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
+{
+ CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));
auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
CHECK_OR_FALSE(ifm_circle->dim(3).known());
uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
- CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
-
- if (_pv == PatternVersion::Version_0)
- {
- CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
- }
- if (_pv == PatternVersion::Version_1)
- {
- CHECK_OR_FALSE(is_quasi_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
- }
-
add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
CHECK_OR_FALSE(add_as_variance);
// TODO Support regarding broadcast
CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
- if (_pv == PatternVersion::Version_0)
- {
- CHECK_OR_FALSE(is_instance_mean_v0(mean_as_variance));
- }
- if (_pv == PatternVersion::Version_1)
- {
- CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
- }
+ CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
CHECK_OR_FALSE(sqdiff);
- if (_pv == PatternVersion::Version_0)
- {
- loco::Node *ifm_should_be = nullptr;
- CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
- CHECK_OR_FALSE(ifm == ifm_should_be);
- CHECK_OR_FALSE(is_instance_mean_v0(mean_of_ifm));
- CHECK_OR_FALSE(ifm == mean_of_ifm->input());
- }
- if (_pv == PatternVersion::Version_1)
- {
- loco::Node *reshape_should_be = nullptr;
- CHECK_OR_FALSE(
- luci::fill(&reshape_should_be, &mean_of_reshape).with_commutative_args_of(sqdiff));
- CHECK_OR_FALSE(reshape_of_ifm == reshape_should_be);
- CHECK_OR_FALSE(is_instance_mean_v1(mean_of_reshape));
- CHECK_OR_FALSE(reshape_of_ifm == mean_of_reshape->input());
- }
+ loco::Node *ifm_should_be = nullptr;
+ CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+ CHECK_OR_FALSE(is_instance_mean_v1(mean_of_ifm));
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
CHECK_OR_FALSE(const_as_beta);
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
- if (_pv == PatternVersion::Version_0)
- {
- CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
- }
- if (_pv == PatternVersion::Version_1)
- {
- CHECK_OR_FALSE(is_quasi_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
- }
+ luci::CircleRsqrt *rsqrt_should_be = nullptr;
+ luci::CircleMean *mean_of_ifm_should_be = nullptr;
mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
CHECK_OR_FALSE(mul_as_scaled_mean);
+ CHECK_OR_FALSE(luci::fill(&rsqrt_should_be, &mean_of_ifm_should_be)
+ .with_commutative_args_of(mul_as_scaled_mean));
+ CHECK_OR_FALSE(rsqrt == rsqrt_should_be);
+ CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
- luci::CircleMul *mul_gamma_should_be = nullptr;
- luci::CircleMean *mean_of_ifm_should_be = nullptr;
- luci::CircleMean *mean_of_reshape_should_be = nullptr;
+ // mul_gamma is absent
+ // const_as_gamma is assumed to be 1.0
+ auto graph = add_as_terminal->graph();
+ const_as_gamma = make_const_one(graph, 1.0f);
+ const_as_gamma->name(add_as_terminal->name() + "/gamma");
- if (_pv == PatternVersion::Version_0)
- {
- CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
- .with_commutative_args_of(mul_as_scaled_mean));
- CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
- CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
- }
- if (_pv == PatternVersion::Version_1)
+ _matched = true;
+ return true;
+}
+
+bool InstanceNormPattern::matched()
+{
+ if (_matched)
+ return true;
+
+ // Check order is DFS
+
+ switch (_pv)
{
- CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_reshape_should_be)
- .with_commutative_args_of(mul_as_scaled_mean));
- CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
- CHECK_OR_FALSE(mean_of_reshape == mean_of_reshape_should_be);
+ case PatternVersion::Version_1:
+ return match<PatternVersion::Version_1>();
+ case PatternVersion::Version_2:
+ return match<PatternVersion::Version_2>();
+ case PatternVersion::Version_3:
+ return match<PatternVersion::Version_3>();
+ case PatternVersion::Version_4:
+ return match<PatternVersion::Version_4>();
+ case PatternVersion::Version_5:
+ return match<PatternVersion::Version_5>();
+
+ default:
+ break;
}
-#undef CHECK_OR_FALSE
- _matched = true;
- return true;
+ throw std::runtime_error("Invalid InstanceNorm PatternVersion.");
}
+#undef CHECK_OR_FALSE
+
/**
* Instance norm pattern would be fused like following diagram:
*
- * [In] --------------------------- CircleInstanceNorm --- [Out]
- * / /
- * const_as_gamma --- TFLReshape --- /
- * /
- * const_as_beta ---- TFLReshape ---
+ * [In] -------------- CircleInstanceNorm --- [Out]
+ * / /
+ * const_as_gamma ---- /
+ * /
+ * const_as_beta -----
*
* Note
* - 'const_as_gamma' and 'const_as_beta' are from original graph
* - Value of 'const_as_epsilon' would be copied to CircleInstanceNorm's attribute
- * - TFLReshape is added as CircleInstanceNorm only accept 1D tensor
+ * - The shapes of the two CircleConst nodes are updated since CircleInstanceNorm only accepts 1D tensors
* - 'CircleConst --- TFLReshape' is expected to be fused in constant folding for Reshape
*/
-void fuse_instance_norm(const InstanceNormPattern &p)
+
+class FuseInstanceNorm final
{
- assert(p.matched());
+public:
+ FuseInstanceNorm(const InstanceNormPattern &p) : _p(p) {}
- auto graph = p.add_as_terminal->graph();
+public:
+ void apply(void);
- // Version 0 and 1 need to reshape
- if (p.version() != InstanceNormPattern::Version_2)
+private:
+ template <InstanceNormPattern::PatternVersion> void apply(void);
+
+private:
+ void reshape_gamma_beta(void);
+ luci::CircleInstanceNorm *create_inst_norm(loco::Graph *graph);
+
+private:
+ const InstanceNormPattern &_p;
+};
+
+void FuseInstanceNorm::reshape_gamma_beta()
+{
+ // Versions 1, 3, and 5 need to reshape gamma/beta
{
- p.const_as_gamma->rank(1);
- p.const_as_gamma->dim(0).set(p.const_as_gamma->size<loco::DataType::FLOAT32>());
- p.const_as_beta->rank(1);
- p.const_as_beta->dim(0).set(p.const_as_beta->size<loco::DataType::FLOAT32>());
+ _p.const_as_gamma->rank(1);
+ _p.const_as_gamma->dim(0).set(_p.const_as_gamma->size<loco::DataType::FLOAT32>());
+ _p.const_as_beta->rank(1);
+ _p.const_as_beta->dim(0).set(_p.const_as_beta->size<loco::DataType::FLOAT32>());
- p.const_as_gamma->shape_status(luci::ShapeStatus::UNDEFINED);
- p.const_as_beta->shape_status(luci::ShapeStatus::UNDEFINED);
+ _p.const_as_gamma->shape_status(luci::ShapeStatus::UNDEFINED);
+ _p.const_as_beta->shape_status(luci::ShapeStatus::UNDEFINED);
}
+}
+luci::CircleInstanceNorm *FuseInstanceNorm::create_inst_norm(loco::Graph *graph)
+{
// Make Instance Norm to replace
auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>();
- instance_norm->input(p.ifm);
- instance_norm->gamma(p.const_as_gamma);
- instance_norm->beta(p.const_as_beta);
- float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
+ instance_norm->input(_p.ifm);
+ instance_norm->gamma(_p.const_as_gamma);
+ instance_norm->beta(_p.const_as_beta);
+ float epsilon = _p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
instance_norm->epsilon(epsilon);
- instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
- // NOTE unique name should be assigned in export
- instance_norm->name("InstanceNorm");
+ if (_p.add_as_terminal != nullptr)
+ {
+ instance_norm->fusedActivationFunction(_p.add_as_terminal->fusedActivationFunction());
+ // NOTE unique name should be assigned in export
+ instance_norm->name("FusedInstanceNorm/" + _p.add_as_terminal->name());
+ }
+ else
+ {
+ // VERSION_4
+ assert(_p.div != nullptr);
+ instance_norm->fusedActivationFunction(_p.div->fusedActivationFunction());
+ instance_norm->name("FusedInstanceNorm/" + _p.div->name());
+ }
+
+ return instance_norm;
+}
+
+template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_1>()
+{
+ auto graph = _p.add_as_terminal->graph();
+
+ reshape_gamma_beta();
+
+ auto instance_norm = create_inst_norm(graph);
+
+ // set origin
+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
+ luci::get_origin(_p.mean_of_ifm),
+ luci::get_origin(_p.sqdiff),
+ luci::get_origin(_p.mean_as_variance),
+ luci::get_origin(_p.add_as_variance),
+ luci::get_origin(_p.rsqrt),
+ luci::get_origin(_p.mul_gamma),
+ luci::get_origin(_p.mul_as_scaled_ifm),
+ luci::get_origin(_p.mul_as_scaled_mean),
+ luci::get_origin(_p.sub),
+ luci::get_origin(_p.add_as_terminal)};
+
+ luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
+
+ replace(_p.add_as_terminal).with(instance_norm);
+}
+
+template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_2>()
+{
+ auto graph = _p.add_as_terminal->graph();
+
+ auto instance_norm = create_inst_norm(graph);
+
+ // set origin
+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
+ luci::get_origin(_p.mean_of_ifm),
+ luci::get_origin(_p.sqdiff),
+ luci::get_origin(_p.mean_as_variance),
+ luci::get_origin(_p.add_as_variance),
+ luci::get_origin(_p.pow),
+ luci::get_origin(_p.sub),
+ luci::get_origin(_p.div),
+ luci::get_origin(_p.mul_gamma),
+ luci::get_origin(_p.add_as_terminal)};
+
+ luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
+
+ replace(_p.add_as_terminal).with(instance_norm);
+}
+
+template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_3>()
+{
+ auto graph = _p.add_as_terminal->graph();
+
+ reshape_gamma_beta();
+
+ auto instance_norm = create_inst_norm(graph);
+
+ // set origin
+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
+ luci::get_origin(_p.mean_of_ifm),
+ luci::get_origin(_p.sub),
+ luci::get_origin(_p.mean_of_ifm_2),
+ luci::get_origin(_p.sub_2),
+ luci::get_origin(_p.square),
+ luci::get_origin(_p.mean_as_variance),
+ luci::get_origin(_p.sqrt),
+ luci::get_origin(_p.add_as_variance),
+ luci::get_origin(_p.div),
+ luci::get_origin(_p.mul_gamma),
+ luci::get_origin(_p.add_as_terminal)};
+
+ luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
+
+ replace(_p.add_as_terminal).with(instance_norm);
+}
+
+template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_4>()
+{
+ auto graph = _p.div->graph();
+
+ auto instance_norm = create_inst_norm(graph);
// set origin
std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
- luci::get_origin(p.sqdiff),
- luci::get_origin(p.mean_as_variance),
- luci::get_origin(p.add_as_variance),
- luci::get_origin(p.mul_gamma),
- luci::get_origin(p.sub),
- luci::get_origin(p.add_as_terminal)};
- if (p.version() == InstanceNormPattern::PatternVersion::Version_0)
+ luci::get_origin(_p.mean_of_ifm),
+ luci::get_origin(_p.sub),
+ luci::get_origin(_p.mean_of_ifm_2),
+ luci::get_origin(_p.sub_2),
+ luci::get_origin(_p.square),
+ luci::get_origin(_p.mean_as_variance),
+ luci::get_origin(_p.sqrt),
+ luci::get_origin(_p.add_as_variance),
+ luci::get_origin(_p.div)};
+
+ luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
+
+ replace(_p.div).with(instance_norm);
+}
+
+template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
+{
+ auto graph = _p.add_as_terminal->graph();
+
+ reshape_gamma_beta();
+
+ auto instance_norm = create_inst_norm(graph);
+
+ // set origin
+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
+ luci::get_origin(_p.mean_of_ifm),
+ luci::get_origin(_p.sqdiff),
+ luci::get_origin(_p.mean_as_variance),
+ luci::get_origin(_p.add_as_variance),
+ luci::get_origin(_p.rsqrt),
+ luci::get_origin(_p.mul_as_scaled_ifm),
+ luci::get_origin(_p.mul_as_scaled_mean),
+ luci::get_origin(_p.sub),
+ luci::get_origin(_p.add_as_terminal)};
+
+ luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
+
+ replace(_p.add_as_terminal).with(instance_norm);
+}
+
+void FuseInstanceNorm::apply()
+{
+ assert(_p.matched());
+
+ switch (_p.version())
{
- origin_vec.push_back(luci::get_origin(p.mean_of_ifm));
- origin_vec.push_back(luci::get_origin(p.rsqrt));
- origin_vec.push_back(luci::get_origin(p.mul_as_scaled_ifm));
- origin_vec.push_back(luci::get_origin(p.mul_as_scaled_mean));
+ case InstanceNormPattern::PatternVersion::Version_1:
+ apply<InstanceNormPattern::PatternVersion::Version_1>();
+ break;
+ case InstanceNormPattern::PatternVersion::Version_2:
+ apply<InstanceNormPattern::PatternVersion::Version_2>();
+ break;
+ case InstanceNormPattern::PatternVersion::Version_3:
+ apply<InstanceNormPattern::PatternVersion::Version_3>();
+ break;
+ case InstanceNormPattern::PatternVersion::Version_4:
+ apply<InstanceNormPattern::PatternVersion::Version_4>();
+ break;
+ case InstanceNormPattern::PatternVersion::Version_5:
+ apply<InstanceNormPattern::PatternVersion::Version_5>();
+ break;
+
+ default:
+ break;
}
- if (p.version() == InstanceNormPattern::PatternVersion::Version_1)
+}
+
+} // namespace
+
+namespace
+{
+
+class PostFusion final
+{
+public:
+ PostFusion(luci::CircleInstanceNorm *inst_norm) : _inst_norm(inst_norm) {}
+
+private:
+ uint32_t input_channel(void);
+
+ luci::CircleConst *match_const_channel(luci::CircleConst *, uint32_t);
+ bool match_const_gamma_channel(void);
+ bool match_const_beta_channel(void);
+
+public:
+ bool process(void);
+
+private:
+ luci::CircleInstanceNorm *_inst_norm = nullptr;
+};
+
+/**
+ * @brief Return the channel (C) value, or 0 if the shape status is not valid
+ */
+uint32_t PostFusion::input_channel(void)
+{
+ auto input = dynamic_cast<luci::CircleNode *>(_inst_norm->input());
+ if (input == nullptr)
+ return 0;
+ if (input->shape_status() != luci::ShapeStatus::VALID)
+ return 0;
+
+ auto input_rank = input->rank();
+ if (input_rank < 1)
+ return 0;
+
+ // assume channel-last
+ return input->dim(input_rank - 1).value();
+}
+
+/**
+ * @brief Return a new CircleConst with C channels if the channel count of input_const is 1 and != C
+ */
+luci::CircleConst *PostFusion::match_const_channel(luci::CircleConst *input_const, uint32_t C)
+{
+ luci::CircleConst *new_input_const = nullptr;
+
+ auto input_chn = input_const->dim(0).value();
+ if (input_chn == 1 && input_chn != C)
+ {
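+ // Broadcast the single scalar value to a rank-1 constant with C elements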
+ float value = input_const->at<loco::DataType::FLOAT32>(0);
+ auto clone = luci::clone_node(input_const, input_const->graph());
+
+ new_input_const = loco::must_cast<luci::CircleConst *>(clone);
+ new_input_const->rank(1);
+ new_input_const->dim(0).set(C);
+ new_input_const->size<loco::DataType::FLOAT32>(C);
+ for (uint32_t c = 0; c < C; ++c)
+ new_input_const->at<loco::DataType::FLOAT32>(c) = value;
+ }
+
+ return new_input_const;
+}
+
+/**
+ * @brief Broadcast gamma to match input channel if CircleConst
+ */
+bool PostFusion::match_const_gamma_channel(void)
+{
+ auto const_as_gamma = dynamic_cast<luci::CircleConst *>(_inst_norm->gamma());
+ if (const_as_gamma == nullptr)
+ return false;
+
+ auto C = input_channel();
+ if (C == 0)
+ return false;
+
+ auto new_const_as_gamma = match_const_channel(const_as_gamma, C);
+ if (new_const_as_gamma == nullptr)
+ return false;
+
+ _inst_norm->gamma(new_const_as_gamma);
+
+ return true;
+}
+
+/**
+ * @brief Broadcast beta to match input channel if CircleConst
+ */
+bool PostFusion::match_const_beta_channel(void)
+{
+ auto const_as_beta = dynamic_cast<luci::CircleConst *>(_inst_norm->beta());
+ if (const_as_beta == nullptr)
+ return false;
+
+ auto C = input_channel();
+ if (C == 0)
+ return false;
+
+ auto new_const_as_beta = match_const_channel(const_as_beta, C);
+ if (new_const_as_beta == nullptr)
+ return false;
+
+ _inst_norm->beta(new_const_as_beta);
+
+ return true;
+}
+
+bool PostFusion::process(void)
+{
+ bool changed = false;
+
+ if (match_const_gamma_channel())
+ changed = true;
+ if (match_const_beta_channel())
+ changed = true;
+
+ return changed;
+}
+
+} // namespace
+
+namespace
+{
+
+bool is_add_input_mul_const(luci::CircleAdd *add)
+{
+ luci::CircleMul *p_mul = nullptr;
+ luci::CircleConst *p_const = nullptr;
+
+ return luci::fill(&p_mul, &p_const).with_commutative_args_of(add);
+}
+
+bool fuse_instance_norm(luci::CircleAdd *add)
+{
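+ // Guess the pattern version from the direct producers of `add`:
+ // if one input is a Mul and the other a constant, start with Version_2; otherwise Version_1.
+ // Each first guess has a fallback version that is tried below when matching fails.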
+ InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_1;
+
+ if (is_add_input_mul_const(add))
+ pv = InstanceNormPattern::PatternVersion::Version_2;
+
+ InstanceNormPattern pattern(add, pv);
+ if (pattern.matched())
+ {
+ FuseInstanceNorm fuse(pattern);
+ fuse.apply();
+ return true;
+ }
+
+ if (pv == InstanceNormPattern::PatternVersion::Version_1)
{
- origin_vec.push_back(luci::get_origin(p.reshape_of_ifm));
- origin_vec.push_back(luci::get_origin(p.mean_of_reshape));
- origin_vec.push_back(luci::get_origin(p.rsqrt));
- origin_vec.push_back(luci::get_origin(p.mul_as_scaled_mean));
- origin_vec.push_back(luci::get_origin(p.mul_as_scaled_reshape));
+ // if Version_1 failed, try with Version_5
+ pv = InstanceNormPattern::PatternVersion::Version_5;
+ InstanceNormPattern pattern(add, pv);
+ if (pattern.matched())
+ {
+ FuseInstanceNorm fuse(pattern);
+ fuse.apply();
+ return true;
+ }
}
- if (p.version() == InstanceNormPattern::PatternVersion::Version_2)
+ else if (pv == InstanceNormPattern::PatternVersion::Version_2)
{
- origin_vec.push_back(luci::get_origin(p.mean_of_ifm));
- origin_vec.push_back(luci::get_origin(p.pow));
- origin_vec.push_back(luci::get_origin(p.div));
+ // if Version_2 failed, try with Version_3
+ pv = InstanceNormPattern::PatternVersion::Version_3;
+ InstanceNormPattern pattern(add, pv);
+ if (pattern.matched())
+ {
+ FuseInstanceNorm fuse(pattern);
+ fuse.apply();
+ return true;
+ }
}
- luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
- replace(p.add_as_terminal).with(instance_norm);
+ return false;
+}
+
+bool fuse_instance_norm(luci::CircleDiv *div)
+{
+ InstanceNormPattern::PatternVersion pv = InstanceNormPattern::PatternVersion::Version_4;
+
+ InstanceNormPattern pattern(div, pv);
+ if (pattern.matched())
+ {
+ FuseInstanceNorm fuse(pattern);
+ fuse.apply();
+ return true;
+ }
+
+ return false;
+}
+
+bool post_fusion(luci::CircleInstanceNorm *inst_norm)
+{
+ PostFusion postfusion(inst_norm);
+
+ return postfusion.process();
}
} // namespace
bool FuseInstanceNormPass::run(loco::Graph *g)
{
bool changed = false;
- luci::CircleAdd *add;
- InstanceNormPattern::PatternVersion pv;
+ // Check Version_1, Version_2, Version_3, Version_5
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- auto reshape = dynamic_cast<luci::CircleReshape *>(node);
- if (not reshape)
- {
- add = dynamic_cast<luci::CircleAdd *>(node);
- if (not add)
- continue;
- pv = InstanceNormPattern::PatternVersion::Version_0;
-
- auto x = loco::must_cast<luci::CircleNode *>(add->x());
- auto y = loco::must_cast<luci::CircleNode *>(add->y());
- if ((x->opcode() == luci::CircleOpcode::MUL &&
- y->opcode() == luci::CircleOpcode::CIRCLECONST) ||
- (x->opcode() == luci::CircleOpcode::CIRCLECONST &&
- y->opcode() == luci::CircleOpcode::MUL))
- pv = InstanceNormPattern::PatternVersion::Version_2;
- }
- else
- {
- add = dynamic_cast<luci::CircleAdd *>(reshape->tensor());
- if (not add)
- continue;
- pv = InstanceNormPattern::PatternVersion::Version_1;
- }
+ auto add = dynamic_cast<luci::CircleAdd *>(node);
+ if (not add)
+ continue;
- InstanceNormPattern pattern(add, pv);
- if (not pattern.matched())
+ if (fuse_instance_norm(add))
+ changed = true;
+ }
+
+ // Check Version_4(from DIV) if MUL-ADD pattern is not found
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto div = dynamic_cast<luci::CircleDiv *>(node);
+ if (not div)
continue;
- fuse_instance_norm(pattern);
- changed = true;
+ if (fuse_instance_norm(div))
+ changed = true;
+ }
+
+ // Post processing of FuseInstanceNorm
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto inst_norm = dynamic_cast<luci::CircleInstanceNorm *>(node);
+ if (not inst_norm)
+ continue;
+
+ if (post_fusion(inst_norm))
+ changed = true;
}
return changed;
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
-
-TEST(FuseInstanceNormPass, is_quasi_1D_with_dummy_dim)
-{
- luci::CircleConst const_node;
-
- setShape(const_node, {});
- EXPECT_FALSE(is_quasi_1D_with_dummy_dim(&const_node, 8));
-
- setShape(const_node, {1});
- EXPECT_FALSE(is_quasi_1D_with_dummy_dim(&const_node, 8));
-
- setShape(const_node, {8});
- EXPECT_FALSE(is_quasi_1D_with_dummy_dim(&const_node, 8));
-
- setShape(const_node, {1, 2, 1, 8, 1});
- EXPECT_FALSE(is_quasi_1D_with_dummy_dim(&const_node, 8));
-
- setShape(const_node, {8, 3});
- EXPECT_FALSE(is_quasi_1D_with_dummy_dim(&const_node, 8));
-
- setShape(const_node, {8, 1});
- EXPECT_FALSE(is_quasi_1D_with_dummy_dim(&const_node, 8));
-
- setShape(const_node, {1, 8, 1});
- EXPECT_TRUE(is_quasi_1D_with_dummy_dim(&const_node, 8));
-
- setShape(const_node, {1, 1, 1, 8, 1});
- EXPECT_TRUE(is_quasi_1D_with_dummy_dim(&const_node, 8));
-}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseMeanWithMeanPass.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+/**
+ * Fuse two Mean operations into one Mean operation with merged reduction indices
+ *
+ * BEFORE
+ * |
+ * [CircleMean, axis<1>]
+ * |
+ * [CircleMean, axis<1>]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleMean, axis<1,2>] [CircleMean, axis<1>]
+ * | |
+ * [CircleMean, axis<1>]
+ *
+ */
+luci::CircleConst *create_fused_indices(luci::CircleConst *indices,
+ const std::set<uint32_t> &indices_set)
+{
+ auto name = indices->name();
+
+ auto fused_indices_const = indices->graph()->nodes()->create<luci::CircleConst>();
+ fused_indices_const->dtype(indices->dtype());
+ fused_indices_const->rank(1);
+ fused_indices_const->dim(0) = indices_set.size();
+ fused_indices_const->size<loco::DataType::S32>(indices_set.size());
+ fused_indices_const->shape_status(luci::ShapeStatus::VALID);
+ fused_indices_const->name(name);
+
+ auto curr_index = 0;
+ for (auto it = indices_set.begin(); it != indices_set.end(); it++)
+ {
+ fused_indices_const->at<loco::DataType::S32>(curr_index) = *it;
+ curr_index++;
+ }
+
+ return fused_indices_const;
+}
+
+bool fuse_mean_with_mean(luci::CircleMean *mean)
+{
+ // Get reduction indices of current CircleMean operation.
+ auto indices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
+ if (not indices)
+ return false;
+ assert(indices->dtype() == loco::DataType::S32);
+
+ // Check whether previous node is CircleMean operation or not.
+ auto prev_mean = dynamic_cast<luci::CircleMean *>(mean->input());
+ if (not prev_mean)
+ return false;
+
+ // Check whether the input rank of the previous CircleMean operation is at least 2.
+ // This optimization works only if it is.
+ auto input = loco::must_cast<luci::CircleNode *>(prev_mean->input());
+ if (input->shape_status() != luci::ShapeStatus::VALID)
+ return false;
+ auto input_rank = input->rank();
+ if (input_rank < 2)
+ return false;
+
+ // Check whether the current CircleMean and the previous CircleMean
+ // have the same keep_dims parameter.
+ // If they don't, keep the graph unchanged.
+ if (mean->keep_dims() != prev_mean->keep_dims())
+ return false;
+
+ // Get reduction indices of previous CircleMean operation.
+ auto prev_indices = dynamic_cast<luci::CircleConst *>(prev_mean->reduction_indices());
+ if (not prev_indices)
+ return false;
+ assert(prev_indices->dtype() == loco::DataType::S32);
+
+ // Get sizes of indices of current CircleMean operation and previous CircleMean operation.
+ auto indices_size = indices->size<loco::DataType::S32>();
+ auto prev_indices_size = prev_indices->size<loco::DataType::S32>();
+
+ // Get set of indices of previous CircleMean operation.
+ std::set<uint32_t> indices_set;
+ for (uint32_t i = 0; i < prev_indices_size; i++)
+ {
+ auto index = prev_indices->at<loco::DataType::S32>(i);
+ if (index < 0)
+ index += input_rank;
+ indices_set.insert(index);
+ }
+
+ // Get the vector of input indices that remain untouched
+ // after the previous CircleMean operation.
+ std::vector<uint32_t> input_indices_vector;
+ for (uint32_t i = 0; i < input_rank; i++)
+ {
+ if (indices_set.find(i) == indices_set.end())
+ input_indices_vector.push_back(i);
+ }
+
+ // Get final set of merged indices.
+ for (uint32_t i = 0; i < indices_size; i++)
+ {
+ auto index = indices->at<loco::DataType::S32>(i);
+ if (index < 0)
+ index += input_rank;
+ indices_set.insert(input_indices_vector.at(index));
+ }
+
+ // Create merged indices.
+ auto fused_indices_const = create_fused_indices(indices, indices_set);
+
+ auto name = mean->name();
+ assert(name.length() > 0);
+
+ // Create and configure new CircleMean operation.
+ auto fused_mean = mean->graph()->nodes()->create<luci::CircleMean>();
+ fused_mean->reduction_indices(fused_indices_const);
+ fused_mean->input(prev_mean->input());
+ fused_mean->keep_dims(mean->keep_dims());
+ fused_mean->name(name + "/Mean");
+
+ // Replace old CircleMean operations with new CircleMean operation with merged indices.
+ replace(mean).with(fused_mean);
+ luci::add_origin(fused_mean,
+ luci::composite_origin({luci::get_origin(mean), luci::get_origin(prev_mean)}));
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseMeanWithMeanPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto mean = dynamic_cast<luci::CircleMean *>(node);
+ if (not mean)
+ continue;
+
+ if (fuse_mean_with_mean(mean))
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseMeanWithMeanPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ * |
+ * [CircleMean, axis<1>]
+ * |
+ * [CircleMean, axis<1>]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleMean, axis<1,2>] [CircleMean, axis<1>]
+ * | |
+ * [CircleMean, axis<1>]
+ *
+ */
+class MeansGraphlet
+{
+public:
+ MeansGraphlet() = default;
+
+ void init(loco::Graph *g)
+ {
+ _mean1 = g->nodes()->create<luci::CircleMean>();
+ _mean2 = g->nodes()->create<luci::CircleMean>();
+ _indices1 = g->nodes()->create<luci::CircleConst>();
+ _indices2 = g->nodes()->create<luci::CircleConst>();
+
+ _mean1->name("mean1");
+ _mean2->name("mean2");
+ _indices1->name("indices1");
+ _indices2->name("indices2");
+ }
+
+public:
+ luci::CircleMean *mean1() { return _mean1; }
+ luci::CircleMean *mean2() { return _mean2; }
+
+protected:
+ luci::CircleMean *_mean1 = nullptr;
+ luci::CircleMean *_mean2 = nullptr;
+ luci::CircleConst *_indices1 = nullptr;
+ luci::CircleConst *_indices2 = nullptr;
+};
+
+class FuseMeanWithMeanTestGraph : public TestIOGraph, public MeansGraphlet
+{
+public:
+ FuseMeanWithMeanTestGraph() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1, 64, 20, 32}, {1, 20});
+ MeansGraphlet::init(g());
+
+ _indices1->rank(1);
+ _indices1->dtype(loco::DataType::S32);
+ _indices1->size<loco::DataType::S32>(1);
+ _indices1->at<loco::DataType::S32>(0) = static_cast<int32_t>(1);
+ _indices1->shape_status(luci::ShapeStatus::VALID);
+
+ _indices2->rank(1);
+ _indices2->dtype(loco::DataType::S32);
+ _indices2->size<loco::DataType::S32>(1);
+ _indices2->at<loco::DataType::S32>(0) = static_cast<int32_t>(2);
+ _indices2->shape_status(luci::ShapeStatus::VALID);
+
+ _mean1->input(input());
+ _mean1->reduction_indices(_indices1);
+
+ _mean2->input(_mean1);
+ _mean2->reduction_indices(_indices2);
+
+ output()->from(_mean2);
+ }
+};
+
+} // namespace
+
+TEST(FuseMeanWithMeanPassTest, name)
+{
+ luci::FuseMeanWithMeanPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(FuseMeanWithMeanPassTest, fuse_mean_with_mean)
+{
+ FuseMeanWithMeanTestGraph g;
+ luci::FuseMeanWithMeanPass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+}
+
+TEST(FuseMeanWithMeanPassTest, fuse_mean_with_mean_NEG)
+{
+ FuseMeanWithMeanTestGraph g;
+ luci::FuseMeanWithMeanPass pass;
+
+ g.init();
+
+ // Add a CircleRelu operation between the two CircleMean operations
+ auto relu = g.g()->nodes()->create<luci::CircleRelu>();
+ relu->name("relu");
+ relu->features(g.mean1());
+ g.mean2()->input(relu);
+
+ // Due to the CircleRelu operation, pass will not be applied
+ EXPECT_FALSE(pass.run(g.g()));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseTransposeWithMeanPass.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/Nodes/CircleConst.h>
+
+namespace
+{
+
+/**
+ * Fuse Transpose with Mean if possible
+ *
+ * BEFORE
+ * |
+ * [CircleTranspose, perm<0, 2, 3, 1>]
+ * |
+ * [CircleMean, axis<3>]
+ * |
+ *
+ * AFTER
+ * | |
+ * [CircleMean, axis<1>] [CircleTranspose, perm<0, 2, 3, 1>]
+ * | |
+ * [CircleMean, axis<3>]
+ *
+ */
+
+/**
+ * @brief Create a const for fused reduction indices
+ */
+luci::CircleConst *create_fused_indices(luci::CircleConst *rindices,
+ const std::vector<uint32_t> &fused_rindices)
+{
+ assert(rindices != nullptr); // FIX_CALLER_UNLESS
+
+ if (rindices->dtype() != loco::DataType::S32)
+ return nullptr;
+
+ assert(fused_rindices.size() == rindices->size<loco::DataType::S32>());
+
+ auto fused_rindices_const = luci::clone(rindices);
+ auto name = rindices->name();
+ assert(name.length() > 0); // FIX_CALLER_UNLESS
+ fused_rindices_const->name(name + "_fused");
+
+ for (uint32_t i = 0; i < fused_rindices.size(); ++i)
+ {
+ fused_rindices_const->at<loco::DataType::S32>(i) = fused_rindices.at(i);
+ }
+
+ return fused_rindices_const;
+}
+
+bool const_has_value_s32(const luci::CircleConst *circle_const, int32_t value)
+{
+ if (circle_const->dtype() != loco::DataType::S32)
+ return false;
+
+ uint32_t size = circle_const->size<loco::DataType::S32>();
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ if (circle_const->at<loco::DataType::S32>(i) == value)
+ return true;
+ }
+
+ return false;
+}
+
+bool fuse_transpose_with_mean(luci::CircleMean *mean)
+{
+ auto transpose = dynamic_cast<luci::CircleTranspose *>(mean->input());
+ if (not transpose)
+ return false;
+
+ // Get reduction indices of CircleMean operation.
+ auto rindices = dynamic_cast<luci::CircleConst *>(mean->reduction_indices());
+ if (not rindices)
+ return false;
+
+ if (rindices->dtype() != loco::DataType::S32)
+ return false;
+
+ if (mean->keep_dims() != false)
+ return false;
+
+ auto perm = dynamic_cast<luci::CircleConst *>(transpose->perm());
+ if (not perm)
+ return false;
+
+ std::vector<uint32_t> axes_after_reduction;
+ std::vector<uint32_t> orig_reduced_axes;
+ for (uint32_t axis = 0; axis < perm->size<loco::DataType::S32>(); ++axis)
+ {
+ uint32_t original_axis = static_cast<uint32_t>(perm->at<loco::DataType::S32>(axis));
+
+ if (const_has_value_s32(rindices, axis))
+ {
+ orig_reduced_axes.push_back(original_axis);
+ continue;
+ }
+
+ axes_after_reduction.push_back(original_axis);
+ }
+
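+ // Fusion is valid only if the non-reduced axes keep their relative order once the
+ // transpose is removed. e.g., with perm <0, 2, 3, 1> and reduction index <3>, the
+ // reduced axis maps back to original axis 1 while the remaining axes {0, 2, 3}
+ // stay sorted, so the Transpose can be dropped and the fused reduction index is <1>.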
+ if (not std::is_sorted(axes_after_reduction.begin(), axes_after_reduction.end()))
+ return false;
+
+ auto fused_rindices = create_fused_indices(rindices, orig_reduced_axes);
+ if (not fused_rindices)
+ return false;
+
+ // Create and configure new CircleMean operation.
+ auto fused_mean = mean->graph()->nodes()->create<luci::CircleMean>();
+ fused_mean->reduction_indices(fused_rindices);
+ fused_mean->input(transpose->a());
+ fused_mean->keep_dims(false);
+ fused_mean->name(mean->name() + "/Transpose");
+
+ // Replace old CircleMean operation with new CircleMean operation with merged indices.
+ replace(mean).with(fused_mean);
+ luci::add_origin(fused_mean,
+ luci::composite_origin({luci::get_origin(mean), luci::get_origin(transpose)}));
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseTransposeWithMeanPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto mean = dynamic_cast<luci::CircleMean *>(node);
+ if (not mean)
+ continue;
+
+ if (fuse_transpose_with_mean(mean))
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseTransposeWithMeanPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ * |
+ * [CircleTranspose, perm<0, 2, 3, 1>]
+ * |
+ * [CircleMean, axis<3>]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleMean, axis<1>] [CircleTranspose, perm<0, 2, 3, 1>]
+ * | |
+ * [CircleMean, axis<3>]
+ *
+ */
+class FuseTransposeWithMeanTestGraph : public TestIOGraph
+{
+public:
+ FuseTransposeWithMeanTestGraph() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1, 64, 20, 32}, {1, 20, 32});
+
+ _mean = g()->nodes()->create<luci::CircleMean>();
+ _transpose = g()->nodes()->create<luci::CircleTranspose>();
+ _indices = g()->nodes()->create<luci::CircleConst>();
+ _perm = g()->nodes()->create<luci::CircleConst>();
+
+ _mean->name("mean");
+ _transpose->name("transpose");
+ _indices->name("indices");
+ _perm->name("perm");
+
+ _indices->rank(1);
+ _indices->dtype(loco::DataType::S32);
+ _indices->size<loco::DataType::S32>(1);
+ _indices->at<loco::DataType::S32>(0) = static_cast<int32_t>(3);
+ _indices->dim(0) = 1;
+ _indices->shape_status(luci::ShapeStatus::VALID);
+
+ _perm->rank(1);
+ _perm->dtype(loco::DataType::S32);
+ _perm->size<loco::DataType::S32>(4);
+ _perm->dim(0) = 4;
+ _perm->at<loco::DataType::S32>(0) = static_cast<int32_t>(0);
+ _perm->at<loco::DataType::S32>(1) = static_cast<int32_t>(2);
+ _perm->at<loco::DataType::S32>(2) = static_cast<int32_t>(3);
+ _perm->at<loco::DataType::S32>(3) = static_cast<int32_t>(1);
+ _perm->shape_status(luci::ShapeStatus::VALID);
+
+ _transpose->a(input());
+ _transpose->perm(_perm);
+
+ _mean->input(_transpose);
+ _mean->reduction_indices(_indices);
+
+ output()->from(_mean);
+ }
+
+ luci::CircleTranspose *transpose(void) const { return _transpose; }
+ luci::CircleMean *mean(void) const { return _mean; }
+
+private:
+ luci::CircleTranspose *_transpose = nullptr;
+ luci::CircleMean *_mean = nullptr;
+ luci::CircleConst *_indices = nullptr;
+ luci::CircleConst *_perm = nullptr;
+};
+
+} // namespace
+
+TEST(FuseTransposeWithMeanPassTest, name)
+{
+ luci::FuseTransposeWithMeanPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(FuseTransposeWithMeanPassTest, fuse_transpose_with_mean)
+{
+ FuseTransposeWithMeanTestGraph g;
+ luci::FuseTransposeWithMeanPass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ auto fused_mean = dynamic_cast<luci::CircleMean *>(g.output()->from());
+ EXPECT_NE(nullptr, fused_mean);
+
+ auto rindices = dynamic_cast<luci::CircleConst *>(fused_mean->reduction_indices());
+ EXPECT_NE(nullptr, rindices);
+
+ EXPECT_EQ(1, rindices->rank());
+ EXPECT_EQ(1, rindices->dim(0));
+ EXPECT_EQ(1, rindices->size<loco::DataType::S32>());
+ EXPECT_EQ(1, rindices->at<loco::DataType::S32>(0));
+}
+
+TEST(FuseTransposeWithMeanPassTest, fuse_transpose_with_mean_NEG)
+{
+ FuseTransposeWithMeanTestGraph g;
+ luci::FuseTransposeWithMeanPass pass;
+
+ g.init();
+
+ // Add CircleRelu operation between CircleMean and Transpose
+ auto relu = g.g()->nodes()->create<luci::CircleRelu>();
+ relu->name("relu");
+ relu->features(g.transpose());
+ g.mean()->input(relu);
+
+ // Due to the CircleRelu operation, pass will not be applied
+ EXPECT_FALSE(pass.run(g.g()));
+}
return false;
auto input_node = loco::must_cast<luci::CircleNode *>(input);
- return copy_qparam(node, input_node);
+ return copy_qparam(input_node, node);
}
// TODO : Add more Ops (e.g., Transpose)
while (pass.run(&g.g))
;
- EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[0]);
- EXPECT_FLOAT_EQ(0.4, g.conv->quantparam()->scale[1]);
- EXPECT_FLOAT_EQ(0.6, g.conv->quantparam()->scale[2]);
- EXPECT_EQ(-10, g.conv->quantparam()->zerop[0]);
- EXPECT_EQ(0, g.conv->quantparam()->zerop[1]);
- EXPECT_EQ(10, g.conv->quantparam()->zerop[2]);
+ EXPECT_FLOAT_EQ(0.1, g.reshape->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.2, g.reshape->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.3, g.reshape->quantparam()->scale[2]);
+ EXPECT_EQ(0, g.reshape->quantparam()->zerop[0]);
+ EXPECT_EQ(10, g.reshape->quantparam()->zerop[1]);
+ EXPECT_EQ(20, g.reshape->quantparam()->zerop[2]);
}
TEST(PropagateQuantParam, wrong_op_NEG)
scaling_factor = scale_factor_from_min_side > scale_factor_from_max_side
? scale_factor_from_min_side
: scale_factor_from_max_side;
+
+ // protect scale from being very low to avoid overflow/underflow
+ if (scaling_factor < 1e-9)
+ scaling_factor = 1e-9;
+
zp = 0;
nudged_min = static_cast<float>(qmin_double * scaling_factor);
nudged_max = static_cast<float>(qmax_double * scaling_factor);
void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataType quant_type);
+void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type);
+
bool is_weights(CircleNode *node);
bool is_quantized(const CircleNode *node);
return new_node;
}
-void overwrite_quantparam(luci::CircleConcatenation *concat, luci::CircleNode *target)
+void overwrite_quantparam(luci::CircleNode *source, luci::CircleNode *target)
{
- auto concat_qparam = concat->quantparam();
- if (concat_qparam == nullptr)
- throw std::runtime_error("quantparam of concat is not found during overwrite");
+ auto source_qparam = source->quantparam();
+ if (source_qparam == nullptr)
+ throw std::runtime_error("source quantparam is not found during overwrite");
auto target_qparam = target->quantparam();
if (target_qparam == nullptr)
if (target_qparam == nullptr)
throw std::runtime_error("Creating new quant param failed");
}
- target_qparam->min = concat_qparam->min;
- target_qparam->max = concat_qparam->max;
- target_qparam->scale = concat_qparam->scale;
- target_qparam->zerop = concat_qparam->zerop;
- target_qparam->quantized_dimension = concat_qparam->quantized_dimension;
+ target_qparam->min = source_qparam->min;
+ target_qparam->max = source_qparam->max;
+ target_qparam->scale = source_qparam->scale;
+ target_qparam->zerop = source_qparam->zerop;
+ target_qparam->quantized_dimension = source_qparam->quantized_dimension;
}
void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
std::vector<int32_t> quantized_values(size);
for (uint32_t i = 0; i < size; ++i)
{
- auto data = const_node->at<loco::DataType::FLOAT32>(i);
- quantized_values[i] = static_cast<int32_t>(std::round(data * scaling_factor_inv) + zerop);
+ auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
+ double quantized_float = std::round(data * scaling_factor_inv) + zerop;
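+ // Clamp to the int32 range before casting; an out-of-range float-to-int cast is undefined behavior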
+ constexpr auto int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
+ constexpr auto int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
+ quantized_float = std::min(int_max, std::max(int_min, quantized_float));
+
+ quantized_values[i] = static_cast<int32_t>(quantized_float);
}
switch (quant_type)
"fully-connected layer have bias");
}
+void set_act_qparam(luci::CircleNode *node, float scale, int64_t zp)
+{
+ assert(node); // FIX_CALLER_UNLESS
+ assert(node->quantparam()); // FIX_CALLER_UNLESS
+
+ auto qparam = node->quantparam();
+ assert(qparam->scale.size() == 1); // FIX_CALLER_UNLESS
+ assert(qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
+ qparam->scale[0] = scale;
+ qparam->zerop[0] = zp;
+}
+
+/**
+ * @brief Manually set scale/zp of output tensor of special Ops
+ */
+struct QuantizeSpecialActivation final : public luci::CircleNodeMutableVisitor<void>
+{
+ QuantizeSpecialActivation(loco::DataType input, loco::DataType output)
+ : input_type(input), output_type(output)
+ {
+ }
+
+ loco::DataType input_type;
+ loco::DataType output_type;
+
+ void visit(luci::CircleNode *)
+ {
+ // Do nothing by default
+ }
+
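+ // Logistic and Tanh have fixed output ranges ([0, 1] and [-1, 1] respectively),
+ // so their output qparam is set to constants that exactly cover that range.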
+ void visit(luci::CircleLogistic *node)
+ {
+ if (output_type == loco::DataType::U8)
+ set_act_qparam(node, 1.0f / 256.0f, 0);
+ else
+ {
+ assert(output_type == loco::DataType::S16);
+ set_act_qparam(node, 1.0f / 32768.0f, 0);
+ }
+ }
+
+ void visit(luci::CircleTanh *node)
+ {
+ if (output_type == loco::DataType::U8)
+ set_act_qparam(node, 2.0f / 256.0f, 128);
+ else
+ {
+ assert(output_type == loco::DataType::S16);
+ set_act_qparam(node, 1.0f / 32768.0f, 0);
+ }
+ }
+
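+ // StridedSlice, SplitOut, and UnpackOut only rearrange data,
+ // so the output reuses the qparam of the (already quantized) input.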
+ void visit(luci::CircleStridedSlice *node)
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(node->input());
+ auto i_qparam = input->quantparam();
+ assert(i_qparam);
+ assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
+ assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
+ auto i_scale = i_qparam->scale[0];
+ auto i_zp = i_qparam->zerop[0];
+
+ set_act_qparam(node, i_scale, i_zp);
+ }
+
+ void visit(luci::CircleSplitOut *node)
+ {
+ auto split = loco::must_cast<luci::CircleSplit *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(split->input());
+ auto i_qparam = input->quantparam();
+ assert(i_qparam);
+ assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
+ assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
+ auto i_scale = i_qparam->scale[0];
+ auto i_zp = i_qparam->zerop[0];
+
+ set_act_qparam(node, i_scale, i_zp);
+ }
+
+ void visit(luci::CircleUnpackOut *node)
+ {
+ auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(unpack->value());
+ auto i_qparam = input->quantparam();
+ assert(i_qparam);
+ assert(i_qparam->scale.size() == 1); // FIX_CALLER_UNLESS
+ assert(i_qparam->zerop.size() == 1); // FIX_CALLER_UNLESS
+ auto i_scale = i_qparam->scale[0];
+ auto i_zp = i_qparam->zerop[0];
+
+ set_act_qparam(node, i_scale, i_zp);
+ }
+
+ // TODO Move Softmax, Floor, Ceil from QuantizeActivation to here
+};
+
/**
* @brief QuantizeActivation quantizes tensors for activations
* @details Quantize using recorded min/max values
{
// Quantize using recorded min/max
auto quantparam = circle_node->quantparam();
+ assert(quantparam);
assert(quantparam->min.size() == 1); // only support layer-wise quant
assert(quantparam->max.size() == 1); // only support layer-wise quant
auto min = quantparam->min[0];
auto max = quantparam->max[0];
+ // Special values: Softmax output is bounded to [0, 1]
+ if (circle_node->opcode() == luci::CircleOpcode::SOFTMAX)
+ {
+ min = 0.0f;
+ max = 1.0f;
+ }
+
float scaling_factor{0};
int64_t zp{0};
float nudged_min{0};
circle_node->dtype(loco::DataType::S16);
}
+ // Nodes fused with activation functions which need special quantization
+ auto fused_act_node =
+ dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(circle_node);
+ if (fused_act_node != nullptr &&
+ fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
+ {
+ if (output_type == loco::DataType::U8)
+ {
+ scaling_factor = 2.0f / 256.0f;
+ zp = 128;
+ }
+ else
+ {
+ assert(output_type == loco::DataType::S16);
+ scaling_factor = 1.0f / 32768.0f;
+ zp = 0;
+ }
+ }
+
+ // The output of these Ops should be an integer, so the scale should be an integer
+ // TODO Handle cases where the integer scale needs to be propagated
+ if (circle_node->opcode() == CircleOpcode::FLOOR ||
+ circle_node->opcode() == CircleOpcode::FLOOR_DIV ||
+ circle_node->opcode() == CircleOpcode::FLOOR_MOD ||
+ circle_node->opcode() == CircleOpcode::CEIL)
+ {
+ assert(scaling_factor >= 0); // FIX_ME_UNLESS
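+ // e.g., a recorded scale of 0.25 becomes 1.0 and a scale of 2.7 becomes 3.0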
+ scaling_factor = scaling_factor < 1 ? 1.0f : std::round(scaling_factor);
+ }
+
circle_node->quantparam()->min.clear();
circle_node->quantparam()->max.clear();
circle_node->quantparam()->scale.push_back(scaling_factor);
circle_node->quantparam()->zerop.push_back(zp);
}
+ // Fix special attributes
+ if (circle_node->opcode() == luci::CircleOpcode::CAST)
+ {
+ auto *cast = loco::must_cast<luci::CircleCast *>(circle_node);
+ auto *cast_input = loco::must_cast<luci::CircleNode *>(cast->x());
+
+ // make sure that cast_input is already quantized
+ assert(cast_input->dtype() != loco::DataType::FLOAT32);
+ cast->in_data_type(cast_input->dtype());
+ cast->out_data_type(cast->dtype());
+ }
}
return false;
}
auto const_bias = loco::must_cast<luci::CircleConst *>(node);
assert(const_bias->dtype() == loco::DataType::FLOAT32);
+ // If input is const, it is quantized here, not in QuantizeActivation
+ if (auto const_input = dynamic_cast<luci::CircleConst *>(input))
+ {
+ quant_const(const_input, output_type);
+ }
+
CircleConst *new_bias = nullptr;
if (granularity == QuantizationGranularity::ChannelWise)
{
- assert(input->quantparam()->scale.size() == 1); // input scale's layer-wise
- auto input_scale = input->quantparam()->scale[0];
+ auto input_q = input->quantparam();
+ assert(input_q);
+ assert(input_q->scale.size() == 1); // input scale's layer-wise
+ auto input_scale = input_q->scale[0];
assert(weight->quantparam() != nullptr); // weight scale's channel-wise
auto weight_scale = weight->quantparam()->scale;
}
else
{
- assert(input->quantparam()->scale.size() == 1); // Only support per-layer quant
- auto input_scale = input->quantparam()->scale[0];
+ auto input_q = input->quantparam();
+ assert(input_q);
+ assert(input_q->scale.size() == 1); // Only support per-layer quant
+ auto input_scale = input_q->scale[0];
- assert(weight->quantparam()->scale.size() == 1); // Only support per-layer quant
- auto weight_scale = weight->quantparam()->scale[0];
+ auto weight_q = weight->quantparam();
+ assert(weight_q);
+ assert(weight_q->scale.size() == 1); // Only support per-layer quant
+ auto weight_scale = weight_q->scale[0];
float scaling_factor{0};
int64_t zp{0};
bool visit(luci::CircleNode *) { return false; }
};
+/** EXAMPLE
+ *
+ * BEFORE
+ *
+ * [CircleNode] [CircleConst]
+ * (qparam1) (FP32)
+ * \ /
+ * \ /
+ * [CirclePack]
+ * (qparam2)
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst] [CircleConst] <- Dead node
+ * (qparam2) (qparam2) (FP32)
+ * \ /
+ * \ /
+ * [CirclePack]
+ * (qparam2)
+ *
+ * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
+ */
+void propagate_pack_quantparam(luci::CirclePack *pack, loco::DataType quant_type)
+{
+ assert(pack->quantparam() != nullptr);
+
+ const auto num_inputs = pack->values_count();
+
+ for (uint32_t i = 0; i < num_inputs; i++)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
+
+ // Skip if this input is PACK Op
+ if (node->opcode() == luci::CircleOpcode::PACK)
+ continue;
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of pack Op");
+
+ const auto pack_qparam = pack->quantparam();
+ if (pack_qparam == nullptr)
+ throw std::runtime_error("quantparam of pack is not found during propagation");
+
+ assert(pack_qparam->scale.size() == 1);
+ assert(pack_qparam->zerop.size() == 1);
+ const auto scaling_factor = pack_qparam->scale[0];
+ const auto zerop = pack_qparam->zerop[0];
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, quant_type);
+ pack->values(i, new_const);
+ overwrite_quantparam(pack, new_const);
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ continue;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(pack, node);
+ }
+ }
+}
+
/**
* @brief Quantize const input tensors using min/max of const values
*/
case luci::CircleOpcode::ARG_MAX:
case luci::CircleOpcode::ARG_MIN:
case luci::CircleOpcode::BATCH_TO_SPACE_ND:
+ case luci::CircleOpcode::LOCAL_RESPONSE_NORMALIZATION:
case luci::CircleOpcode::MEAN:
+ case luci::CircleOpcode::MIRROR_PAD:
case luci::CircleOpcode::PAD:
case luci::CircleOpcode::REDUCE_ANY:
case luci::CircleOpcode::REDUCE_PROD:
case luci::CircleOpcode::SQRT:
case luci::CircleOpcode::SUB:
case luci::CircleOpcode::TANH:
+ case luci::CircleOpcode::UNPACK:
// Quantize all const inputs using their values
for (uint32_t i = 0; i < arity; i++)
{
quant_const(const_node, output_type);
break;
+ case luci::CircleOpcode::PADV2:
+ // First and third constant inputs are quantized
+ // Second input should not be quantized (e.g., paddings)
+ // Quant params are propagated either from output range to the non-constant input
+ // or from input to output and constant values
+ propagate_pad_v2_quantparam(loco::must_cast<CirclePadV2 *>(node), output_type);
+ break;
+
+ case luci::CircleOpcode::PACK:
+ // Quant param is propagated from output to inputs
+ propagate_pack_quantparam(loco::must_cast<CirclePack *>(node), output_type);
+ break;
+
default:
for (uint32_t i = 0; i < arity; i++)
{
}
}
+/**
+ * Tells whether PadV2 quantization should ignore the padding value.
+ * In that case, the padding const is quantized with the input's parameters and may be clipped.
+ */
+bool ignore_pad_v2_const_quantization(luci::CirclePadV2 *pad)
+{
+ // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
+ // TODO use metadata hints to detect this case
+ auto const_value_node = dynamic_cast<luci::CircleConst *>(pad->arg(2));
+ if (!const_value_node)
+ return false;
+ if (const_value_node->dtype() == loco::DataType::FLOAT32)
+ {
+ float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
+ if (const_value == std::numeric_limits<float>::lowest())
+ return true;
+ }
+ return false;
+}
+
+/** BEFORE
+ *
+ * [CircleNode] [CircleConst] [CircleConst]
+ * (U8 qparam1) (S32) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam2)
+ *
+ * AFTER (case 1)
+ *
+ * By default qparam is propagated from output to inputs to meet backend requirements.
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam2) (S32) (U8 qparam2) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam2)
+ *
+ * AFTER (case 2)
+ *
+ * When the padding value is the lowest float value,
+ * qparam is propagated from the input to the output and the constant.
+ *
+ * This is a special case for a pad constructed by optimization, needed to guarantee that
+ * an extremely large negative constant does not stretch the output quantization range.
+ *
+ * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam1) (S32) (U8 qparam1) (FP32)
+ * \ | /
+ * \ | /
+ * [CirclePadV2]
+ * (U8 qparam1)
+ */
+void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2, loco::DataType quant_type)
+{
+ if (ignore_pad_v2_const_quantization(pad_v2))
+ {
+ // Propagate quantization parameters from the input to the output and the padding const value
+ auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
+ overwrite_quantparam(pad_v2_input, pad_v2);
+
+ auto const_value_node = dynamic_cast<luci::CircleConst *>(pad_v2->arg(2));
+ auto new_const = luci::clone(const_value_node);
+
+ const auto pad_v2_input_qparam = pad_v2_input->quantparam();
+ assert(pad_v2_input_qparam != nullptr);
+ assert(pad_v2_input_qparam->scale.size() == 1);
+ const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
+ const auto zerop = pad_v2_input_qparam->zerop.at(0);
+
+ quant_const_values(new_const, scaling_factor, zerop, quant_type);
+ overwrite_quantparam(pad_v2_input, new_const);
+ pad_v2->constant_values(new_const);
+ return;
+ }
+
+ // Propagate quantization parameters from the output to the inputs,
+ // to fit both the input and constant_values in one quant range.
+ auto quant_input = [pad_v2, quant_type](void (CirclePadV2::*arg_setter)(loco::Node *),
+ uint32_t arg) {
+ auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
+
+ // Quantize constant values
+ if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
+ if (is_quantized(const_node))
+ return;
+
+ if (const_node->dtype() != loco::DataType::FLOAT32)
+ throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
+
+ const auto pad_v2_qparam = pad_v2->quantparam();
+ if (pad_v2_qparam == nullptr)
+ throw std::runtime_error("quantparam of PadV2 is not found during propagation");
+
+ assert(pad_v2_qparam->scale.size() == 1);
+ const auto scaling_factor = pad_v2_qparam->scale.at(0);
+ const auto zerop = pad_v2_qparam->zerop.at(0);
+
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, quant_type);
+ overwrite_quantparam(pad_v2, new_const);
+ (pad_v2->*arg_setter)(new_const);
+ }
+ // Subsequent PadV2 Ops quant params are not propagated
+ else if (node->opcode() == luci::CircleOpcode::PADV2)
+ {
+ return;
+ }
+ else
+ {
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ return;
+
+ // Non-const input must have been quantized
+ assert(node->quantparam() != nullptr);
+ overwrite_quantparam(pad_v2, node);
+ }
+ };
+
+ quant_input(&CirclePadV2::input, 0);
+ quant_input(&CirclePadV2::constant_values, 2);
+}
+
bool QuantizeWithMinMaxPass::run(loco::Graph *g)
{
LOGGER(l);
quantize_const_inputs(circle_node, _output_dtype);
}
+ // Update qparam of output of special Ops
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ QuantizeSpecialActivation qsa(_input_dtype, _output_dtype);
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ circle_node->accept(&qsa);
+ }
+
// Update output dtype
auto graph_outputs = g->outputs();
for (auto node : loco::output_nodes(g))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ auto node_name = [&circle_node]() {
+ if (circle_node->name().length() == 0)
+ return std::string("(noname)");
+
+ return circle_node->name();
+ };
+
// Verify Type
if (_quantized_dtype == Type::U8)
{
VerifyQuantizedNodeU8Type vt;
if (!circle_node->accept(&vt))
- throw std::runtime_error("Wrong data type");
+ throw std::runtime_error("Wrong data type detected in " + node_name());
}
else if (_quantized_dtype == Type::S16)
{
VerifyQuantizedNodeS16Type vt;
if (!circle_node->accept(&vt))
- throw std::runtime_error("Wrong data type");
+ throw std::runtime_error("Wrong data type detected in " + node_name());
}
// Verify Granularity
{
VerifyQuantizedNodeLayerWiseGranularity vg;
if (!circle_node->accept(&vg))
- throw std::runtime_error("Wrong granularity");
+ throw std::runtime_error("Wrong granularity detected in " + node_name());
}
else if (_granularity == Granularity::ChannelWise)
{
VerifyQuantizedNodeChannelWiseGranularity vg;
if (!circle_node->accept(&vg))
- throw std::runtime_error("Wrong granularity");
+ throw std::runtime_error("Wrong granularity detected in " + node_name());
}
}
}
verifier.verify(g->g());
}
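+// Helper function to reduce duplicate test codes
+// The dtype of `target` is overwritten with `wrong_dtype` after quantization,
+// so the verifier is expected to fail.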
+void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
+ Granularity granularity, Type wrong_dtype,
+ luci::CircleNode *target)
+{
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
+ pass.run(g->g());
+
+ target->dtype(wrong_dtype);
+
+ luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
+ verifier.verify(g->g());
+}
+
// Helper function to reduce duplicate test codes
// Assumption: g->output()->from() is the target node
void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
loco::Node *gamma(void) const { return _instnorm->gamma(); }
loco::Node *beta(void) const { return _instnorm->beta(); }
-public:
+private:
luci::CircleInstanceNorm *_instnorm = nullptr;
luci::CircleConst *_input = nullptr;
luci::CircleConst *_gamma = nullptr;
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleLogistic *_logistic = nullptr;
};
+class LocalResponseNormalizationTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({1, 2, 2, 32}, {1, 2, 2, 32});
+ _lrn = g()->nodes()->create<luci::CircleLocalResponseNormalization>();
+ {
+ _lrn->input(input());
+ }
+ output()->from(_lrn);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleLocalResponseNormalization *_lrn = nullptr;
+};
+
class SoftmaxTestGraph final : public SimpleTestGraph
{
public:
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleSoftmax *_softmax = nullptr;
};
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleSpaceToBatchND *_stob = nullptr;
luci::CircleConst *_block_shape = nullptr;
luci::CircleConst *_paddings = nullptr;
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleSpaceToDepth *_stod = nullptr;
};
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleSlice *_slice = nullptr;
luci::CircleConst *_begin = nullptr;
luci::CircleConst *_size = nullptr;
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleSplit *_split = nullptr;
luci::CircleSplitOut *_split_o1 = nullptr;
luci::CircleConst *_split_dim = nullptr;
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleStridedSlice *_slice = nullptr;
luci::CircleConst *_begin = nullptr;
luci::CircleConst *_end = nullptr;
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleReshape *_reshape = nullptr;
luci::CircleConst *_shape = nullptr;
};
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleTanh *_tanh = nullptr;
};
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleFloor *_floor = nullptr;
};
}
public:
+ // NOTE: Do not accidentally override `luci::CircleNode* input(void)`
+ loco::Node *input_argmax(void) { return _argmax->input(); }
+ loco::Node *dimension(void) { return _argmax->dimension(); }
+
+private:
luci::CircleArgMax *_argmax = nullptr;
luci::CircleConst *_dimension = nullptr;
};
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleBatchToSpaceND *_btos = nullptr;
luci::CircleConst *_block_shape = nullptr;
luci::CircleConst *_crops = nullptr;
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleDepthToSpace *_dtos = nullptr;
};
+class PackTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({16}, {32});
+ _param = create_dummy_const<Type::FLOAT32>(g(), {16});
+ _pack = g()->nodes()->create<luci::CirclePack>(2);
+ {
+ _pack->values(0, input());
+ _pack->values(1, _param);
+ _pack->axis(0);
+ }
+ output()->from(_pack);
+
+ set_minmax_to_non_const(g(), -1, 1);
+
+ // Set min/max of the input
+ // pack's qparam will be propagated and overwrite the input's qparam
+ auto input = loco::must_cast<luci::CircleNode *>(pack()->values(0));
+ auto qp = input->quantparam();
+ qp->min[0] = -0.5;
+ qp->max[0] = 0.5;
+ }
+
+public:
+ luci::CirclePack *pack(void) { return _pack; }
+
+private:
+ luci::CirclePack *_pack = nullptr;
+ luci::CircleConst *_param = nullptr;
+};
+
class PadTestGraph final : public SimpleTestGraph
{
public:
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CirclePad *_pad = nullptr;
luci::CircleConst *_paddings = nullptr;
};
+class PadV2TestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _paddings = g()->nodes()->create<luci::CircleConst>();
+ {
+ _paddings->dtype(Type::S32);
+ }
+ _constant_values = create_dummy_const<Type::FLOAT32>(g(), {1});
+ _pad = g()->nodes()->create<luci::CirclePadV2>();
+ {
+ _pad->input(input());
+ _pad->paddings(_paddings);
+ _pad->constant_values(_constant_values);
+ }
+ output()->from(_pad);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CirclePadV2 *_pad = nullptr;
+ luci::CircleConst *_paddings = nullptr;
+ luci::CircleConst *_constant_values = nullptr;
+};
+
+class MirrorPadTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _paddings = g()->nodes()->create<luci::CircleConst>();
+ {
+ _paddings->dtype(Type::S32);
+ }
+ _constant_values = create_dummy_const<Type::FLOAT32>(g(), {1});
+ _mirror_pad = g()->nodes()->create<luci::CircleMirrorPad>();
+ {
+ _mirror_pad->input(input());
+ _mirror_pad->paddings(_paddings);
+ _mirror_pad->mode(luci::MirrorPadMode::REFLECT);
+ }
+ output()->from(_mirror_pad);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleMirrorPad *_mirror_pad = nullptr;
+ luci::CircleConst *_paddings = nullptr;
+ luci::CircleConst *_constant_values = nullptr;
+};
+
class TransposeTestGraph final : public SimpleTestGraph
{
public:
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleTranspose *_transpose = nullptr;
luci::CircleConst *_perm = nullptr;
};
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleConcatenation *_concat = nullptr;
luci::CircleConst *_param = nullptr;
};
loco::Node *x(void) const { return _op->x(); }
loco::Node *y(void) const { return _op->y(); }
-public:
+private:
Op *_op = nullptr;
luci::CircleConst *_y = nullptr;
};
loco::Node *x(void) const { return _op->x(); }
loco::Node *y(void) const { return _op->y(); }
-public:
+private:
Op *_op = nullptr;
luci::CircleConst *_y = nullptr;
};
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleRsqrt *_rsqrt = nullptr;
};
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleSqrt *_sqrt = nullptr;
};
set_minmax_to_non_const(g(), -1, 1);
}
-public:
+private:
luci::CircleElu *_elu = nullptr;
};
luci::CircleConst *_size = nullptr;
};
+class ResizeNearestNeighborTestGraph final : public luci::test::TestIOGraph
+{
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1, 4, 4, 1}, {1, 8, 8, 1});
+
+ _size = create_const<Type::S32, int32_t>(g(), {2}, {8, 8});
+ _resize_nearest_neighbor = g()->nodes()->create<luci::CircleResizeNearestNeighbor>();
+ {
+ _resize_nearest_neighbor->input(input());
+ _resize_nearest_neighbor->size(_size);
+ }
+ output()->from(_resize_nearest_neighbor);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleResizeNearestNeighbor *_resize_nearest_neighbor = nullptr;
+ luci::CircleConst *_size = nullptr;
+};
+
+class UnpackTestGraph final : public luci::test::TestIOGraph
+{
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1, 32}, {32});
+ _unpack = g()->nodes()->create<luci::CircleUnpack>();
+ {
+ _unpack->value(input());
+ _unpack->axis(0);
+ _unpack->num(1);
+ }
+ _unpack_o1 = g()->nodes()->create<luci::CircleUnpackOut>();
+ {
+ _unpack_o1->input(_unpack);
+ _unpack_o1->index(0);
+ }
+
+ output()->from(_unpack_o1);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleUnpack *_unpack = nullptr;
+ luci::CircleUnpackOut *_unpack_o1 = nullptr;
+ luci::CircleConst *_unpack_dim = nullptr;
+};
+
} // namespace
// Quantize and verify with given configurations
EXPECT_ANY_THROW(quantize_and_verify_with_wrong_granularity(&g, type, granularity)); \
} while (0)
+// Quantize and verify with wrong type
+// Users can specify the test target
+#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ EXPECT_ANY_THROW( \
+ quantize_and_verify_with_wrong_type(&g, type, granularity, wrong_dtype, node)); \
+ } while (0)
+
// Quantize and verify with wrong granularity
// Users can specify the test target
#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, LocalResponseNormalization)
+{
+ TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, LocalResponseNormalization_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(LocalResponseNormalizationTestGraph, Type::U8, Granularity::LayerWise,
+ Type::S16);
+ TEST_WITH_WRONG_TYPE(LocalResponseNormalizationTestGraph, Type::U8, Granularity::ChannelWise,
+ Type::S16);
+ TEST_WITH_WRONG_TYPE(LocalResponseNormalizationTestGraph, Type::S16, Granularity::ChannelWise,
+ Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, LocalResponseNormalization_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(LocalResponseNormalizationTestGraph, Type::U8,
+ Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(LocalResponseNormalizationTestGraph, Type::U8,
+ Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(LocalResponseNormalizationTestGraph, Type::S16,
+ Granularity::ChannelWise);
+ SUCCEED();
+}
+
TEST(QuantizedModelVerifierTest, Logistic)
{
TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::LayerWise);
SUCCEED();
}
-TEST(QuantizedModelVerifierTest, ArgMax_wrong_dimension_type_NEG)
+TEST(QuantizedModelVerifierTest, ArgMax_wrong_input_type_NEG)
{
- ArgMaxTestGraph<Type::S32> g;
- g.init();
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, Type::U8, Granularity::LayerWise);
- pass.run(g.g());
+ TEST_WITH_WRONG_TYPE(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ArgMaxTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise, Type::U8);
- g._dimension->dtype(Type::U8);
-
- luci::QuantizedModelVerifier verifier(Type::U8, Granularity::LayerWise);
- EXPECT_ANY_THROW(verifier.verify(g.g()));
+ TEST_WITH_WRONG_TYPE(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ArgMaxTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
}
-TEST(QuantizedModelVerifierTest, ArgMax_wrong_input_granularity_NEG)
+TEST(QuantizedModelVerifierTest, ArgMax_wrong_dimension_type_NEG)
{
- ArgMaxTestGraph<Type::S32> g;
- g.init();
+ TEST_WITH_WRONG_TYPE_TARGET(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::LayerWise,
+ Type::S16, g.dimension());
+ TEST_WITH_WRONG_TYPE_TARGET(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise,
+ Type::S16, g.dimension());
+ TEST_WITH_WRONG_TYPE_TARGET(ArgMaxTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise,
+ Type::U8, g.dimension());
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, Type::U8, Granularity::LayerWise);
- pass.run(g.g());
+ TEST_WITH_WRONG_TYPE_TARGET(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::LayerWise,
+ Type::S16, g.dimension());
+ TEST_WITH_WRONG_TYPE_TARGET(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise,
+ Type::S16, g.dimension());
+ TEST_WITH_WRONG_TYPE_TARGET(ArgMaxTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise,
+ Type::U8, g.dimension());
+ SUCCEED();
+}
- insert_scale_zp(loco::must_cast<luci::CircleNode *>(g._argmax->input()), 1.0, 1);
+TEST(QuantizedModelVerifierTest, ArgMax_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::LayerWise,
+ g.input_argmax());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise,
+ g.input_argmax());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ArgMaxTestGraph<Type::S32>, Type::S16,
+ Granularity::ChannelWise, g.input_argmax());
- luci::QuantizedModelVerifier verifier(Type::U8, Granularity::LayerWise);
- EXPECT_ANY_THROW(verifier.verify(g.g()));
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::LayerWise,
+ g.input_argmax());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise,
+ g.input_argmax());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ArgMaxTestGraph<Type::S64>, Type::S16,
+ Granularity::ChannelWise, g.input_argmax());
+ SUCCEED();
}
TEST(QuantizedModelVerifierTest, BatchToSpaceND)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, Pack)
+{
+ TEST_WITH_GRAPH(PackTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(PackTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(PackTestGraph, Type::S16, Granularity::ChannelWise);
+
+ // Test if Pack's qparam is propagated to the input
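+ // (Assumption, for illustration only: the expected scale 2/255 and zero point 128 look like
+ // U8 quantization of an input range of roughly [-1, 1] in PackTestGraph; the graph definition
+ // is not shown here, so this is just a hedged reading of the magic numbers below.)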
+ {
+ PackTestGraph g;
+ g.init();
+ quantize_and_verify(g.g(), Type::U8, Granularity::ChannelWise);
+ auto input = loco::must_cast<luci::CircleNode *>(g.pack()->values(0));
+ auto qp = input->quantparam();
+ EXPECT_FLOAT_EQ(2.0 / 255.0, qp->scale[0]);
+ EXPECT_FLOAT_EQ(128, qp->zerop[0]);
+ }
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Pack_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(PackTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(PackTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(PackTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Pack_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(PackTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(PackTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(PackTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
TEST(QuantizedModelVerifierTest, Pad)
{
TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::LayerWise);
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, PadV2)
+{
+ TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(PadV2TestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(PadV2TestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, PadV2_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(PadV2TestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(PadV2TestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(PadV2TestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, PadV2_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(PadV2TestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(PadV2TestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(PadV2TestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, MirrorPad)
+{
+ TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, MirrorPad_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(MirrorPadTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, MirrorPad_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(MirrorPadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(MirrorPadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(MirrorPadTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
TEST(QuantizedModelVerifierTest, Transpose)
{
TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::LayerWise);
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, ResizeNearestNeighbor)
+{
+ TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ResizeNearestNeighborTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, ResizeNearestNeighbor_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(ResizeNearestNeighborTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ResizeNearestNeighborTestGraph, Type::U8, Granularity::ChannelWise,
+ Type::S16);
+ TEST_WITH_WRONG_TYPE(ResizeNearestNeighborTestGraph, Type::S16, Granularity::ChannelWise,
+ Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, ResizeNearestNeighbor_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(ResizeNearestNeighborTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(ResizeNearestNeighborTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(ResizeNearestNeighborTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Unpack)
+{
+ TEST_WITH_GRAPH(UnpackTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(UnpackTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(UnpackTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Unpack_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(UnpackTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(UnpackTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(UnpackTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Unpack_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(UnpackTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(UnpackTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(UnpackTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
#undef TEST_WITH_GRAPH
#undef TEST_WITH_WRONG_TYPE
#undef TEST_WITH_WRONG_GRANULARITY
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveFakeQuantPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+void remove_fake_quant(luci::CircleFakeQuant *fakequant)
+{
+ assert(fakequant != nullptr);
+
+ auto input_node = loco::must_cast<luci::CircleNode *>(fakequant->inputs());
+
+ replace(fakequant).with(input_node);
+}
+
+} // namespace
+
+namespace luci
+{
+/**
+ * BEFORE
+ *
+ * [CircleNode]
+ * |
+ * [CircleFakeQuant]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode]
+ * |
+ * [CircleNode] [CircleFakeQuant]
+ *
+ * CircleFakeQuant OP will be removed from the output graph
+ */
+bool RemoveFakeQuantPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto target_node = dynamic_cast<luci::CircleFakeQuant *>(node);
+ if (target_node != nullptr)
+ {
+ remove_fake_quant(target_node);
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveFakeQuantPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class FakeQuantGraphlet
+{
+public:
+ FakeQuantGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ _fq = g->nodes()->create<luci::CircleFakeQuant>();
+ _fq->name("fq");
+ }
+
+protected:
+ luci::CircleFakeQuant *_fq = nullptr;
+};
+
+class FakeQuantGraph : public TestIOGraph, public FakeQuantGraphlet
+{
+public:
+ FakeQuantGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ FakeQuantGraphlet::init(g());
+
+ _fq->inputs(input());
+
+ output()->from(_fq);
+ }
+};
+
+} // namespace
+
+TEST(RemoveFakeQuantPass, name)
+{
+ luci::RemoveFakeQuantPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveFakeQuantPass, remove_fakequant)
+{
+ FakeQuantGraph g;
+ luci::RemoveFakeQuantPass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ auto *node1 = loco::must_cast<luci::CircleNode *>(g.output()->from());
+ auto *node2 = loco::must_cast<luci::CircleNode *>(g.input());
+ EXPECT_EQ(node1, node2);
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveQuantDequantSeqPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+bool remove_quant_dequant(luci::CircleDequantize *dequant)
+{
+ assert(dequant != nullptr);
+
+ auto quantize = dynamic_cast<luci::CircleQuantize *>(dequant->input());
+ if (quantize == nullptr)
+ return false;
+
+ auto input_node = loco::must_cast<luci::CircleNode *>(quantize->input());
+
+ replace(dequant).with(input_node);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+/**
+ * BEFORE
+ *
+ * [CircleNode]
+ * |
+ * [CircleQuantize]
+ * |
+ * [CircleDequantize]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleQuantize]
+ * | |
+ * [CircleNode] [CircleDequantize]
+ *
+ * The CircleQuant-CircleDequant sequence will be removed from the output graph
+ */
+bool RemoveQuantDequantSeqPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto target_node = dynamic_cast<luci::CircleDequantize *>(node);
+ if (target_node != nullptr)
+ {
+ if (remove_quant_dequant(target_node))
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveQuantDequantSeqPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class QuantDequantGraphlet
+{
+public:
+ QuantDequantGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ _qu = g->nodes()->create<luci::CircleQuantize>();
+ _qu->name("qu");
+
+ _de = g->nodes()->create<luci::CircleDequantize>();
+ _de->name("de");
+ }
+
+protected:
+ luci::CircleQuantize *_qu = nullptr;
+ luci::CircleDequantize *_de = nullptr;
+};
+
+class QuantDequantGraph : public TestIOGraph, public QuantDequantGraphlet
+{
+public:
+ QuantDequantGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ QuantDequantGraphlet::init(g());
+
+ _qu->input(input());
+ _de->input(_qu);
+
+ output()->from(_de);
+ }
+};
+
+} // namespace
+
+TEST(RemoveQuantDequantSeqPass, name)
+{
+ luci::RemoveQuantDequantSeqPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveQuantDequantSeqPass, remove_quantdequant)
+{
+ QuantDequantGraph g;
+ luci::RemoveQuantDequantSeqPass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ auto *node1 = loco::must_cast<luci::CircleNode *>(g.output()->from());
+ auto *node2 = loco::must_cast<luci::CircleNode *>(g.input());
+ EXPECT_EQ(node1, node2);
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ReplaceSubWithAddPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/Nodes/CircleConst.h>
+
+namespace
+{
+
+bool replace_sub_with_const_rhs(luci::CircleSub *sub)
+{
+ auto const_rhs = dynamic_cast<luci::CircleConst *>(sub->y());
+ if (const_rhs == nullptr)
+ return false;
+
+ auto graph = sub->graph();
+
+ auto neg_const_rhs = luci::clone(const_rhs);
+ if (neg_const_rhs->dtype() == loco::DataType::FLOAT32)
+ {
+ for (uint32_t i = 0; i < neg_const_rhs->size<loco::DataType::FLOAT32>(); ++i)
+ neg_const_rhs->at<loco::DataType::FLOAT32>(i) *= -1.0;
+ }
+ else
+ {
+ // TODO Support more data type
+ return false;
+ }
+
+ auto add = graph->nodes()->create<luci::CircleAdd>();
+ add->x(sub->x());
+ add->y(neg_const_rhs);
+ add->name(sub->name());
+ add->fusedActivationFunction(sub->fusedActivationFunction());
+ loco::replace(sub).with(add);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool ReplaceSubWithAddPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto sub = dynamic_cast<luci::CircleSub *>(node))
+ {
+ if (replace_sub_with_const_rhs(sub))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ReplaceSubWithAddPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [lhs] ------------+
+ * +-- [Sub] --
+ * [rhs_const] ------+
+ *
+ * AFTER
+ *
+ * [lhs] ------------+
+ * +-- [Add] --
+ * [neg_rhs_const] --+
+ */
+class SimpleGraph
+{
+public:
+ SimpleGraph()
+ {
+ lhs = g.nodes()->create<luci::CircleInput>();
+ rhs_const = g.nodes()->create<luci::CircleConst>();
+ sub = g.nodes()->create<luci::CircleSub>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ lhs->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ lhs->dtype(loco::DataType::FLOAT32);
+ rhs_const->dtype(loco::DataType::FLOAT32);
+ sub->dtype(loco::DataType::FLOAT32);
+ output->dtype(loco::DataType::FLOAT32);
+
+ lhs->shape({1, 3, 4, 5});
+ rhs_const->shape({}); // scalar
+ sub->shape({1, 3, 4, 5});
+ output->shape({1, 3, 4, 5});
+
+ rhs_const->size<loco::DataType::FLOAT32>(1);
+ rhs_const->at<loco::DataType::FLOAT32>(0) = 1.1;
+
+ sub->x(lhs);
+ sub->y(rhs_const);
+ output->from(sub);
+
+ lhs->name("lhs");
+ rhs_const->name("rhs_const");
+ sub->name("sub");
+ output->name("output");
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *lhs = nullptr;
+ luci::CircleConst *rhs_const = nullptr;
+ luci::CircleSub *sub = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(ReplaceSubWithAdd, name)
+{
+ luci::ReplaceSubWithAddPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(ReplaceSubWithAdd, simple)
+{
+ SimpleGraph g;
+
+ luci::ReplaceSubWithAddPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ auto add = dynamic_cast<luci::CircleAdd *>(g.output->from());
+ EXPECT_NE(nullptr, add);
+
+ auto neg_rhs_const = dynamic_cast<luci::CircleConst *>(add->y());
+ EXPECT_NE(nullptr, neg_rhs_const);
+ EXPECT_EQ(0, neg_rhs_const->rank());
+ EXPECT_FLOAT_EQ(-1.1, neg_rhs_const->at<loco::DataType::FLOAT32>(0));
+}
+
+TEST(ReplaceSubWithAdd, wrong_op_NEG)
+{
+ SimpleGraph g;
+
+ auto mul = g.g.nodes()->create<luci::CircleMul>();
+ mul->x(g.sub->x());
+ mul->y(g.sub->y());
+ loco::replace(g.sub).with(mul);
+
+ luci::ReplaceSubWithAddPass pass;
+ auto changed = pass.run(&g.g);
+
+ EXPECT_EQ(false, changed);
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
+
+#include "flatbuffers/flexbuffers.h"
+#include <loco/IR/DataTypeTraits.h>
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+#include <loco.h>
+#include <oops/InternalExn.h>
+
+namespace
+{
+
+template <typename T> std::vector<T> to_vector(const flexbuffers::TypedVector &typed_vec)
+{
+ std::vector<T> answer(typed_vec.size());
+
+ for (uint32_t i = 0; i < answer.size(); ++i)
+ {
+ answer[i] = typed_vec[i].As<T>();
+ }
+
+ return answer;
+}
+
+luci::Padding string_to_padding(const std::string &pad_str)
+{
+ if (pad_str == "VALID")
+ return luci::Padding::VALID;
+ if (pad_str == "SAME")
+ return luci::Padding::SAME;
+
+ return luci::Padding::UNDEFINED;
+}
+
+template <typename NodeT> void set_stride(NodeT *node, const luci::Stride &stride)
+{
+ node->stride()->h(stride.h());
+ node->stride()->w(stride.w());
+}
+
+template <typename NodeT> void set_filter(NodeT *node, const luci::Filter &filter)
+{
+ node->filter()->h(filter.h());
+ node->filter()->w(filter.w());
+}
+
+void init_name_and_origin(luci::CircleNode *node, const std::string &name,
+ const std::shared_ptr<luci::CircleNodeOrigin> &origin)
+{
+ node->name(name);
+ luci::add_origin(node, origin);
+}
+
+template <typename NodeT> NodeT *none_act_func(NodeT *node)
+{
+ node->fusedActivationFunction(luci::FusedActFunc::NONE);
+ return node;
+}
+
+luci::CircleCast *create_cast(luci::CircleNode *input, loco::DataType in_type,
+ loco::DataType out_type)
+{
+ auto cast = input->graph()->nodes()->create<luci::CircleCast>();
+
+ cast->in_data_type(in_type);
+ cast->out_data_type(out_type);
+ cast->dtype(out_type);
+
+ cast->x(input);
+
+ return cast;
+}
+
+template <loco::DataType DT> void fill_conv_weights(luci::CircleConst *weights)
+{
+ assert(weights->rank() == 4);
+
+ auto const kn = weights->dim(0).value();
+ auto const kh = weights->dim(1).value();
+ auto const kw = weights->dim(2).value();
+
+ auto elements_size = kn * kh * kw * 1;
+ weights->size<DT>(elements_size);
+
+ for (uint32_t b = 0; b < kn; ++b)
+ {
+ for (uint32_t y = 0; y < kh; ++y)
+ {
+ for (uint32_t x = 0; x < kw; ++x)
+ {
+ auto const idx = (b * kh + y) * kw + x;
+ weights->at<DT>(idx) = (y * kw + x == b) ? 1 : 0;
+ }
+ }
+ }
+}
+
+luci::CircleConst *create_conv_filter(loco::Graph *graph, const uint32_t kh, const uint32_t kw,
+ const uint32_t kn)
+{
+ auto weights = graph->nodes()->create<luci::CircleConst>();
+
+ weights->dtype(loco::DataType::FLOAT32);
+
+ weights->rank(4);
+ weights->dim(0).set(kn);
+ weights->dim(1).set(kh);
+ weights->dim(2).set(kw);
+ weights->dim(3).set(1);
+ weights->shape_status(luci::ShapeStatus::VALID);
+
+ fill_conv_weights<loco::DataType::FLOAT32>(weights);
+
+ return weights;
+}
+
+template <loco::DataType DT> void fill_zero_bias(luci::CircleConst *bias)
+{
+ assert(bias->rank() == 1);
+
+ auto const depth = bias->dim(0).value();
+
+ bias->size<DT>(depth);
+
+ for (uint32_t i = 0; i < depth; ++i)
+ {
+ bias->at<DT>(i) = 0;
+ }
+}
+
+luci::CircleConst *create_zero_bias(loco::Graph *graph, uint32_t depth)
+{
+ auto bias = graph->nodes()->create<luci::CircleConst>();
+
+ bias->dtype(loco::DataType::FLOAT32);
+
+ bias->rank(1);
+ bias->dim(0).set(depth);
+
+ fill_zero_bias<loco::DataType::FLOAT32>(bias);
+
+ return bias;
+}
+
+luci::CircleConst *create_padding_const(loco::Graph *graph, int32_t left_pad, int32_t right_pad,
+ int32_t top_pad, int32_t bottom_pad)
+{
+ auto paddings = graph->nodes()->create<luci::CircleConst>();
+
+ paddings->dtype(loco::DataType::S32);
+
+ paddings->rank(2);
+ paddings->dim(0).set(4);
+ paddings->dim(1).set(2);
+ paddings->size<loco::DataType::S32>(8);
+ paddings->shape_status(luci::ShapeStatus::VALID);
+
+ paddings->at<loco::DataType::S32>(0) = 0;
+ paddings->at<loco::DataType::S32>(1) = 0;
+
+ paddings->at<loco::DataType::S32>(2) = left_pad;
+ paddings->at<loco::DataType::S32>(3) = right_pad;
+
+ paddings->at<loco::DataType::S32>(4) = top_pad;
+ paddings->at<loco::DataType::S32>(5) = bottom_pad;
+
+ paddings->at<loco::DataType::S32>(6) = 0;
+ paddings->at<loco::DataType::S32>(7) = 0;
+
+ return paddings;
+}
+
+template <loco::DataType DT, typename Numeric>
+luci::CircleConst *create_scalar(loco::Graph *graph, Numeric value)
+{
+ auto scalar = graph->nodes()->create<luci::CircleConst>();
+
+ scalar->dtype(DT);
+
+ scalar->rank(0);
+ scalar->size<DT>(1);
+ scalar->shape_status(luci::ShapeStatus::VALID);
+
+ scalar->scalar<DT>() = value;
+
+ return scalar;
+}
+
+luci::CircleConst *create_shape_tensor(loco::Graph *graph, const std::vector<uint32_t> &dims_vec)
+{
+ auto shape = graph->nodes()->create<luci::CircleConst>();
+
+ shape->dtype(loco::DataType::S32);
+
+ shape->rank(1);
+ shape->dim(0).set(dims_vec.size());
+ shape->shape_status(luci::ShapeStatus::VALID);
+
+ shape->size<loco::DataType::S32>(dims_vec.size());
+
+ for (uint32_t i = 0; i < dims_vec.size(); ++i)
+ {
+ shape->at<loco::DataType::S32>(i) = dims_vec[i];
+ }
+
+ return shape;
+}
+
+int32_t compute_full_padding(int32_t input_size, int32_t output_size, int32_t stride,
+ int32_t filter_size)
+{
+ int32_t effective_input = (output_size - 1) * stride + filter_size;
+ int32_t full = effective_input - input_size;
+ // handle extreme cases where part of the input is not used in the computation
+ if (full < 0)
+ full = 0;
+ return full;
+}
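+
+// A rough worked example of the formula above (illustrative only, not used by the pass):
+// for SAME padding with input_size = 5, output_size = 3, stride = 2 and filter_size = 3,
+// the effective input is (3 - 1) * 2 + 3 = 7, so the full padding is 7 - 5 = 2,
+// which callers later split into one padded row/column on each side.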
+
+template <loco::DataType DT>
+void fill_coords_addition(luci::Padding padding, const luci::Stride &stride,
+ const luci::Filter &filter, uint32_t input_height, uint32_t input_width,
+ uint32_t depth, luci::CircleConst *cords)
+{
+ assert(cords->rank() == 4);
+
+ auto const output_height = static_cast<int32_t>(cords->dim(1).value());
+ auto const output_width = static_cast<int32_t>(cords->dim(2).value());
+ {
+ auto const element_counts = 1 * output_height * output_width * 1;
+ cords->size<DT>(element_counts);
+ }
+
+ assert(padding != luci::Padding::UNDEFINED);
+
+ // For VALID padding:
+ int32_t start_y = 0;
+ int32_t start_x = 0;
+
+ // For SAME padding:
+ if (padding == luci::Padding::SAME)
+ {
+ start_y = -compute_full_padding(input_height, output_height, stride.h(), filter.h()) / 2;
+ start_x = -compute_full_padding(input_width, output_width, stride.w(), filter.w()) / 2;
+ }
+
+ auto const step_y = static_cast<int32_t>(stride.h());
+ auto const step_x = static_cast<int32_t>(stride.w());
+
+ for (int32_t y_o = 0, y_i = start_y; y_o < output_height; ++y_o, y_i += step_y)
+ {
+ for (int32_t x_o = 0, x_i = start_x; x_o < output_width; ++x_o, x_i += step_x)
+ {
+ auto const output_idx = y_o * output_width + x_o;
+ auto const input_idx = y_i * static_cast<int32_t>(input_width) + x_i;
+
+ // Add small adjustment value to fix cast operation result that follows "coord addition"
+ // in generated subgraph.
+ //
+ // The Cast operation discards the fractional part of a value, so 1.9996 will be transformed to 1.
+ // This is not a problem when working with float32, because it represents integers precisely,
+ // but it leads to wrong results when working with quantized numbers.
+ //
+ // This value is larger than quantization error,
+ // and small enough to not affect following computations
+ // (in particular multiplication with depth)
+ const float round_adjustment = 1.0f / (depth + 1);
+
+ cords->at<DT>(output_idx) = input_idx + round_adjustment;
+ }
+ }
+}
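+
+// Worked example of the adjustment above (illustrative only): with depth = 3,
+// round_adjustment = 1 / (3 + 1) = 0.25. If quantization later turns a coordinate of 2.0 into
+// roughly 1.9996, the extra 0.25 keeps the value above 2.0 so the final Cast still yields 2,
+// while 0.25 * depth = 0.75 stays below 1 and cannot push the result past the next integer.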
+
+luci::CircleConst *create_coords_addition(loco::Graph *graph, luci::Padding padding,
+ const luci::Stride &stride, const luci::Filter &filter,
+ uint32_t input_height, uint32_t input_width,
+ uint32_t depth, uint32_t output_height,
+ uint32_t output_width)
+{
+ auto cords = graph->nodes()->create<luci::CircleConst>();
+
+ cords->dtype(loco::DataType::FLOAT32);
+
+ cords->rank(4);
+ cords->dim(0).set(1);
+ cords->dim(1).set(output_height);
+ cords->dim(2).set(output_width);
+ cords->dim(3).set(1);
+
+ fill_coords_addition<loco::DataType::FLOAT32>(padding, stride, filter, input_height, input_width,
+ depth, cords);
+
+ return cords;
+}
+
+luci::CircleNode *get_custom_output(const luci::CircleCustom *cop, int32_t idx)
+{
+ auto const outputs = loco::succs(cop);
+ assert(outputs.size() == 2);
+
+ auto output = loco::must_cast<luci::CircleCustomOut *>(*outputs.begin());
+ if (output->index() != idx)
+ {
+ output = loco::must_cast<luci::CircleCustomOut *>(*outputs.rbegin());
+ }
+
+ return output;
+}
+
+luci::CircleNode *max_pool_branch(luci::Padding padding, const luci::Stride &stride,
+ const luci::Filter filter, luci::CircleCustom *cop)
+{
+ auto graph = cop->graph();
+ auto input = cop->inputs(0);
+
+ auto origin = luci::get_origin(cop);
+ auto name = cop->name() + "/Argmax";
+
+ // Create MaxPool
+ auto maxpool = none_act_func(graph->nodes()->create<luci::CircleMaxPool2D>());
+ {
+ init_name_and_origin(maxpool, name + "/MaxPool2D", origin);
+
+ set_stride(maxpool, stride);
+ set_filter(maxpool, filter);
+ maxpool->padding(padding);
+
+ maxpool->value(input);
+ }
+
+ return maxpool;
+}
+
+luci::CircleNode *window_flattened_coord(const std::string &name, luci::Padding padding,
+ const luci::Stride &stride, const luci::Filter filter,
+ int32_t input_height, int32_t input_width,
+ uint32_t output_height, uint32_t output_width,
+ luci::CircleNode *input)
+{
+ auto const graph = input->graph();
+ auto const origin = luci::get_origin(input);
+
+ auto const depth_dimension = 3;
+
+ // Create pad in case of SAME padding
+ luci::CircleNode *conv_input = input;
+ if (padding == luci::Padding::SAME)
+ {
+ // Create a redundant multiplication to combine two nodes with special quantization restrictions:
+ // PadV2 and Split in this case
+ // TODO Introduce special requantize node and fix quantizer?
+ auto requantize = none_act_func(graph->nodes()->create<luci::CircleMul>());
+ init_name_and_origin(requantize, name + "/Requantize", origin);
+ auto zero_const = create_scalar<loco::DataType::FLOAT32>(graph, 1.0f);
+ init_name_and_origin(zero_const, name + "Requantize_const", origin);
+
+ requantize->x(input);
+ requantize->y(zero_const);
+
+ auto pad = graph->nodes()->create<luci::CirclePadV2>();
+ init_name_and_origin(pad, name + "/Pad", origin);
+
+ pad->input(requantize);
+
+ int32_t full_w_pad = compute_full_padding(input_width, output_width, stride.w(), filter.w());
+ int32_t full_h_pad = compute_full_padding(input_height, output_height, stride.h(), filter.h());
+ int32_t left_pad = full_w_pad / 2;
+ int32_t right_pad = full_w_pad - left_pad;
+ int32_t top_pad = full_h_pad / 2;
+ int32_t bottom_pad = full_h_pad - top_pad;
+ auto padding_const = create_padding_const(graph, left_pad, right_pad, top_pad, bottom_pad);
+ init_name_and_origin(padding_const, name + "/Pad_shape", origin);
+ pad->paddings(padding_const);
+
+ auto padding_value =
+ create_scalar<loco::DataType::FLOAT32, float>(graph, std::numeric_limits<float>::lowest());
+ init_name_and_origin(padding_value, name + "/Pad_value", origin);
+ pad->constant_values(padding_value);
+
+ conv_input = pad;
+ }
+ // Create Conv2D to move spatial dimensions to depth
+ auto conv = none_act_func(graph->nodes()->create<luci::CircleConv2D>());
+ {
+ init_name_and_origin(conv, name + "/Conv2D", origin);
+
+ // Padding, Stride and kernel size equal to MaxPool's
+ set_stride(conv, stride);
+ conv->padding(luci::Padding::VALID);
+
+ // depth of kernel is equal to square size
+ auto const kh = filter.h();
+ auto const kw = filter.w();
+ auto const kd = kh * kw;
+
+ // use zero bias
+ auto bias = create_zero_bias(graph, kd);
+ init_name_and_origin(bias, conv->name() + "/Bias", origin);
+
+ // create filter
+ // TODO make shared
+ auto weights = create_conv_filter(graph, kh, kw, kd);
+ init_name_and_origin(weights, conv->name() + "/Weights", origin);
+
+ conv->bias(bias);
+ conv->filter(weights);
+ conv->input(conv_input);
+ }
+
+ // Create ArgMax
+ auto argmax = graph->nodes()->create<luci::CircleArgMax>();
+ {
+ init_name_and_origin(argmax, name + "/ArgMax", origin);
+
+ argmax->output_type(loco::DataType::S32);
+
+ // Create argmax_dim
+ auto argmax_dim = create_scalar<loco::DataType::S32>(graph, depth_dimension);
+ init_name_and_origin(argmax_dim, argmax->name() + "/Dimension", origin);
+
+ argmax->dimension(argmax_dim);
+ argmax->input(conv);
+ }
+
+ // Create Reshape back to rank 4, because ArgMax decreases the rank of the tensor by 1
+ auto reshape = graph->nodes()->create<luci::CircleReshape>();
+ {
+ init_name_and_origin(reshape, name + "/Reshape", origin);
+
+ auto shape = create_shape_tensor(graph, {1, output_height, output_width, 1});
+ init_name_and_origin(shape, reshape->name() + "/Shape", origin);
+
+ reshape->tensor(argmax);
+ reshape->shape(shape);
+ }
+
+ // Create Cast to use float32 instead of int32
+ auto argmax_cast = create_cast(reshape, loco::DataType::S32, loco::DataType::FLOAT32);
+ init_name_and_origin(argmax_cast, argmax->name() + "/Cast", origin);
+
+ return argmax_cast;
+}
+
+// Creates "identity operation" after Floor
+// to force circle-quantizer to requantize the output tensor with scale << 1.
+//
+// Dealing with values of extremely different scales
+// in following binary operations hurts backend precision.
+luci::CircleNode *create_post_floor_requantize_node(luci::CircleFloor *floor)
+{
+ auto graph = floor->graph();
+ auto const origin = luci::get_origin(floor);
+ auto name = floor->name();
+
+ // Use DepthwiseConv2D with identity filter as an "identity operation".
+ //
+ // This operation does not change values, but forces circle-quantizer to use
+ // statistics to compute the qparam scale instead of a fixed scale == 1.0 after floor.
+ // DepthwiseConv2d is not eliminated by optimizations,
+ // so desired scale will reach backend.
+ auto requantizer = none_act_func(graph->nodes()->create<luci::CircleDepthwiseConv2D>());
+ init_name_and_origin(requantizer, name + "/Requantizer", origin);
+
+ requantizer->input(floor);
+
+ auto requantizer_filter = create_scalar<loco::DataType::FLOAT32>(graph, 1.0f);
+ init_name_and_origin(requantizer_filter, name + "/Requantizer/filter", origin);
+ requantizer_filter->rank(4);
+ for (uint32_t i = 0; i < 4; ++i)
+ {
+ requantizer_filter->dim(i) = 1;
+ }
+ requantizer->filter(requantizer_filter);
+
+ auto requantizer_bias = create_zero_bias(graph, 1);
+ init_name_and_origin(requantizer_bias, name + "/Requantizer/bias", origin);
+ requantizer->bias(requantizer_bias);
+
+ requantizer->padding(luci::Padding::VALID);
+ requantizer->stride()->w(1);
+ requantizer->stride()->h(1);
+ requantizer->depthMultiplier(1);
+ requantizer->dilation()->w(1);
+ requantizer->dilation()->h(1);
+
+ return requantizer;
+}
+
+luci::CircleNode *window_y_coord(const std::string &name, const luci::Filter &filter,
+ luci::CircleNode *flattened)
+{
+ auto const graph = flattened->graph();
+ auto const origin = luci::get_origin(flattened);
+
+ auto div = none_act_func(graph->nodes()->create<luci::CircleMul>());
+ {
+ init_name_and_origin(div, name + "/Div", origin);
+
+ // The rounding_adjustment is needed to fix the computation of quantized tensors.
+ //
+ // For example, the float32 value 2.0 could be quantized to 1.996;
+ // after floor it would become 1.0, but the desired answer is still something close to 2.0.
+ //
+ // rounding_adjustment is chosen so it is small enough to not affect float32 computations,
+ // but "Div" change is larger then potential quantization error.
+ //
+ // This computation exploits the fact that div is an x coord in maxpool window,
+ // and lies in defined range [0, filter.h())
+ const float rounding_adjustment = 1.0f / (filter.w() * filter.h());
+ const float divider_value = filter.w() - rounding_adjustment;
+ auto divider = create_scalar<loco::DataType::FLOAT32>(graph, 1.0f / divider_value);
+ init_name_and_origin(divider, div->name() + "/Divider", origin);
+
+ div->x(flattened);
+ div->y(divider);
+ }
+
+ auto floor = graph->nodes()->create<luci::CircleFloor>();
+ {
+ init_name_and_origin(floor, name + "/Floor", origin);
+ floor->x(div);
+ }
+
+ auto requantizer = create_post_floor_requantize_node(floor);
+
+ return requantizer;
+}
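+
+// Worked example of the adjusted divider (illustrative only): for a 2x2 filter,
+// rounding_adjustment = 1 / (2 * 2) = 0.25 and divider_value = 2 - 0.25 = 1.75.
+// A flattened window index of 2 that was quantized down to 1.996 gives floor(1.996 / 2) = 0
+// with the plain divider, but floor(1.996 / 1.75) = 1 with the adjusted one, which is the
+// intended y coordinate.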
+
+luci::CircleNode *window_x_coord(const std::string &name, float filter_width,
+ luci::CircleNode *flattened, luci::CircleNode *y_coord)
+{
+ auto const graph = flattened->graph();
+ auto const origin = luci::get_origin(flattened);
+
+ auto mod = none_act_func(graph->nodes()->create<luci::CircleAdd>());
+ {
+ init_name_and_origin(mod, name + "/Mod", origin);
+
+ auto neg = graph->nodes()->create<luci::CircleNeg>();
+ {
+ init_name_and_origin(neg, mod->name() + "/Neg", origin);
+
+ auto mul = none_act_func(graph->nodes()->create<luci::CircleMul>());
+ {
+ init_name_and_origin(mul, neg->name() + "/Neg", origin);
+
+ auto multipler = create_scalar<loco::DataType::FLOAT32>(graph, filter_width);
+ init_name_and_origin(multipler, mul->name() + "/Multipler", origin);
+
+ mul->x(y_coord);
+ mul->y(multipler);
+ }
+
+ neg->x(mul);
+ }
+
+ mod->x(flattened);
+ mod->y(neg);
+ }
+
+ return mod;
+}
+
+luci::CircleNode *plane_flattened_coord(const std::string &name, uint32_t input_width,
+ luci::CircleNode *y_coord, luci::CircleNode *x_coord,
+ luci::CircleNode *corners)
+{
+ auto const graph = corners->graph();
+ auto const origin = luci::get_origin(corners);
+
+ auto add = none_act_func(graph->nodes()->create<luci::CircleAdd>());
+ {
+ init_name_and_origin(add, name + "/Add", origin);
+
+ auto addition = none_act_func(graph->nodes()->create<luci::CircleAdd>());
+ {
+ init_name_and_origin(addition, add->name() + "/Add", origin);
+
+ auto y_addition = none_act_func(graph->nodes()->create<luci::CircleMul>());
+ {
+ init_name_and_origin(y_addition, addition->name() + "/Mul", origin);
+
+ auto width_scalar = create_scalar<loco::DataType::FLOAT32>(graph, input_width);
+ init_name_and_origin(width_scalar, y_addition->name() + "/Const", origin);
+
+ y_addition->x(y_coord);
+ y_addition->y(width_scalar);
+ }
+
+ addition->x(x_coord);
+ addition->y(y_addition);
+ }
+
+ add->x(addition);
+ add->y(corners);
+ }
+
+ return add;
+}
+
+luci::CircleNode *volume_flattened_coords(const std::string &name, uint32_t channel,
+ uint32_t input_depth, luci::CircleNode *plane)
+{
+ auto const graph = plane->graph();
+ auto const origin = luci::get_origin(plane);
+
+ // Create Mul
+ auto mul = none_act_func(graph->nodes()->create<luci::CircleMul>());
+ {
+ init_name_and_origin(mul, name + "/Mul", origin);
+
+ auto depth_scalar = create_scalar<loco::DataType::FLOAT32>(graph, input_depth);
+ init_name_and_origin(depth_scalar, mul->name() + "/Const", origin);
+
+ mul->x(plane);
+ mul->y(depth_scalar);
+ }
+
+ luci::CircleNode *volume = mul;
+
+ // Add channel number to output
+ if (channel > 0)
+ {
+ // Create Add
+ auto add_ch = none_act_func(graph->nodes()->create<luci::CircleAdd>());
+ init_name_and_origin(add_ch, name + "/Add_Channel", origin);
+
+ auto channel_scalar = create_scalar<loco::DataType::FLOAT32>(graph, channel);
+ init_name_and_origin(channel_scalar, add_ch->name() + "/Const", origin);
+
+ add_ch->x(mul);
+ add_ch->y(channel_scalar);
+
+ volume = add_ch;
+ }
+
+ return volume;
+}
+
+luci::CircleNode *argmax_branch(luci::Padding padding, const luci::Stride &stride,
+ const luci::Filter filter, luci::CircleCustom *cop)
+{
+ auto graph = cop->graph();
+ auto input = loco::must_cast<luci::CircleNode *>(cop->inputs(0));
+ auto output = get_custom_output(cop, 1);
+
+ auto const depth_dimension = 3;
+ auto const input_depth = input->dim(depth_dimension).value();
+ auto const input_height = input->dim(1).value();
+ auto const input_width = input->dim(2).value();
+
+ assert(output->rank() == 4);
+ auto const output_height = output->dim(1).value();
+ auto const output_width = output->dim(2).value();
+
+ auto origin = luci::get_origin(cop);
+ auto name = cop->name() + "/Argmax";
+
+ // Create Split
+ auto split = graph->nodes()->create<luci::CircleSplit>();
+ {
+ init_name_and_origin(split, name + "/Split", origin);
+
+ // Create split_dim
+ auto split_dim = create_scalar<loco::DataType::S32>(graph, depth_dimension);
+ init_name_and_origin(split_dim, split->name() + "/Dim", origin);
+
+ split->num_split(int32_t(input_depth));
+
+ split->split_dim(split_dim);
+ split->input(input);
+ }
+
+ /**
+ * Note: we need to define the idx (relative to the input tensor) of the maximum element in
+ * MaxPool's sliding window. For this we split the input tensor by channels, define the idx in
+ * the sliding window, and convert it to an idx of the source input tensor using FloorDiv, Mul
+ * and Add operations with constant tensors.
+ */
+ std::vector<luci::CircleNode *> branch_outputs(input_depth);
+
+ for (uint32_t br_n = 0; br_n < input_depth; ++br_n)
+ {
+ auto const branch_name = name + "/depth_" + std::to_string(br_n);
+
+ // Create CircleSplitOut
+ auto split_out = graph->nodes()->create<luci::CircleSplitOut>();
+ init_name_and_origin(split_out, branch_name + "/SplitOut", origin);
+ split_out->index(int32_t(br_n));
+ split_out->input(split);
+
+ // Define idx of max element in Window:
+ auto window_coords =
+ window_flattened_coord(branch_name + "/WindowFlat", padding, stride, filter, input_height,
+ input_width, output_height, output_width, split_out);
+
+ auto const window_y = window_y_coord(branch_name + "/WindowY", filter, window_coords);
+ auto const window_x =
+ window_x_coord(branch_name + "/WindowX", filter.w(), window_coords, window_y);
+
+ // Define idx of max element in Plane
+ // This tensor contains coords of left top corners for each window from input tensor
+ auto corners = create_coords_addition(graph, padding, stride, filter, input_height, input_width,
+ input_depth, output_height, output_width);
+ init_name_and_origin(corners, branch_name + "/Const", origin);
+
+ auto plane_coord =
+ plane_flattened_coord(branch_name + "/PlaneFlat", input_width, window_y, window_x, corners);
+
+ // Define volume coords as final value
+ branch_outputs[br_n] =
+ volume_flattened_coords(branch_name + "/VolumeFlat", br_n, input_depth, plane_coord);
+ }
+
+ // Create Concatenation
+ auto concat = none_act_func(graph->nodes()->create<luci::CircleConcatenation>(input_depth));
+ {
+ init_name_and_origin(concat, name + "/Concatenation", origin);
+ concat->axis(depth_dimension);
+
+ for (uint32_t i = 0; i < input_depth; ++i)
+ {
+ concat->values(i, branch_outputs[i]);
+ }
+ }
+
+ // Output of argmax_with_maxpool should be S64 or S32
+ loco::DataType output_dtype = get_custom_output(cop, 1)->dtype();
+ auto output_cast = create_cast(concat, loco::DataType::FLOAT32, output_dtype);
+ init_name_and_origin(output_cast, name + "/Cast", origin);
+
+ return output_cast;
+}
+
+bool resolve_max_pool_with_argmax(luci::CircleCustom *cop)
+{
+#define CHECK_OR_FALSE(condition) \
+ if (not(condition)) \
+ return false;
+
+ const std::vector<uint8_t> custom_options = cop->custom_options();
+ auto map = flexbuffers::GetRoot(custom_options).AsMap();
+
+ // Define params
+ // Note: Only `Targmax` equal to DT_INT64 is supported by tflite converter
+ // Note: Only `data_format` equal to "NHWC" is supported by tflite converter
+ // TODO add support of `include_batch_in_index` param
+ auto ksize_param = to_vector<uint32_t>(map["ksize"].AsTypedVector());
+ auto strides_param = to_vector<uint32_t>(map["strides"].AsTypedVector());
+ auto padding_param = map["padding"].As<std::string>();
+
+ // Batch and depth components of ksize greater than 1 are not supported.
+ CHECK_OR_FALSE(ksize_param.size() == 4);
+ CHECK_OR_FALSE(ksize_param[0] == 1 && ksize_param[3] == 1);
+
+ CHECK_OR_FALSE(strides_param.size() == 4);
+ CHECK_OR_FALSE(strides_param[0] == 1 && strides_param[3] == 1);
+
+ // define Padding
+ auto padding = string_to_padding(padding_param);
+
+ // define Filter
+ luci::Filter filter;
+ filter.h(ksize_param[1]);
+ filter.w(ksize_param[2]);
+
+ // define Stride
+ luci::Stride stride;
+ stride.h(strides_param[1]);
+ stride.w(strides_param[2]);
+
+ // input node
+ auto const input = loco::must_cast<luci::CircleNode *>(cop->inputs(0));
+ CHECK_OR_FALSE(input->dtype() == loco::DataType::FLOAT32);
+ CHECK_OR_FALSE(input->rank() == 4);
+
+ // TODO support batch size > 1 and `include_batch_in_index` option
+ CHECK_OR_FALSE(input->dim(0).value() == 1);
+
+ // output nodes
+ auto const outputs = loco::succs(cop);
+ CHECK_OR_FALSE(outputs.size() == 2);
+ assert(outputs.size() == cop->numOutputs());
+
+ auto output0 = get_custom_output(cop, 0);
+ auto output1 = get_custom_output(cop, 1);
+
+ // From TF documentation: the output of maxpool must have the same type as the input
+ assert(output0->dtype() == input->dtype());
+ assert(output1->dtype() == loco::DataType::S64 || output1->dtype() == loco::DataType::S32);
+
+ // Create MaxPool
+ auto maxpool = max_pool_branch(padding, stride, filter, cop);
+ auto argmax = argmax_branch(padding, stride, filter, cop);
+
+ // The last op of the argmax branch is a Cast, so it should have its dtype initialized
+ assert(argmax->dtype() == output1->dtype());
+
+ // replace old node with new subgraph
+ cop->inputs(0, nullptr);
+ loco::replace(output0).with(maxpool);
+ loco::replace(output1).with(argmax);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ * |
+ * [CircleNode]
+ * |
+ * [CUSTOM(MaxPoolWithArgmax)]
+ * | |
+ * [MaxPool output] [Argmax output]
+ *
+ * AFTER
+ * |
+ * [CircleNode]
+ * / \
+ * [Split over channels] [MaxPool2D]
+ * / | \ \
+ * [Requantize] ... ... [MaxPool output]
+ * |
+ * [PadV2]
+ * |
+ * [Conv2D]
+ * |
+ * [ArgMax]
+ * |
+ * [Reshape to 4d]
+ * |
+ * [Cast to float32]
+ * / |
+ * | [Mul 1/<window width>]
+ * | \
+ * | [Floor]
+ * | |
+ * | [DepthwiseConv2D for requantize]
+ * | / \
+ * | [Mul window width] |
+ * \ / /
+ * \ [Neg] [Mul input width]
+ * \ / /
+ * [Add] /
+ * \ /
+ * [Add]
+ * |
+ * [Add const]
+ * |
+ * [Mul number of channels]
+ * \
+ * [Optional Add with channels id] ... ...
+ * \ | /
+ * [Concatenation]
+ * |
+ * [Cast to int]
+ * |
+ * [Argmax output]
+ */
+bool ResolveCustomOpMaxPoolWithArgmaxPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto cop = dynamic_cast<luci::CircleCustom *>(node);
+ if (not cop)
+ continue;
+
+ if (cop->custom_code() != "MaxPoolWithArgmax")
+ continue;
+
+ if (!resolve_max_pool_with_argmax(cop))
+ continue;
+
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(ResolveCustomOpMaxPoolWithArgmaxPassTest, name)
+{
+ luci::ResolveCustomOpMaxPoolWithArgmaxPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/SubstitutePadV2ToPadPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+#include <vector>
+
+/**
+ * @brief Convert PadV2 op in a certain condition to Pad
+ * @details The conditions for converting PadV2 to Pad are as follows:
+ *
+ * Basic Condition)
+ *
+ * C1) For all i, PadV2.input[i] >= 0
+ * C2) For all c, PadV2.constant_values[c] <= 0
+ * C3) PadV2 == MaxPool2D.value()
+ * C4) number of padded values at left < MaxPool2D.Filter.W
+ * number of padded values at right < MaxPool2D.Filter.W
+ * number of padded values at top < MaxPool2D.Filter.H
+ * number of padded values at bottom < MaxPool2D.Filter.H
+ *
+ * Example graph is as follows:
+ *
+ * %1 = CircleRelu # relu_output[i] >= 0
+ * %2 = CirclePadV2(%1, constant_values <= 0)
+ * %3 = CircleMaxPool2D(%2, ...) # output will be chosen from relu_output
+ *
+ * In this case, it's OK to replace PadV2 with Pad, which uses 0 as padding constant.
+ *
+ * Optional Condition)
+ *
+ * Terminology)
+ * - 'reshaping op' : op that does not change the value of tensor
+ * but changes position of tensor value, e.g., Transpose, Reshape, Slice, etc.
+ *
+ * C5) The input of PadV2 could be a 'reshaping op'. An example is as follows:
+ *
+ * %1 = CircleRelu # output[i] >= 0
+ * %2 = CircleTranspose(%1) # reshaping op
+ * ... # more reshaping ops
+ * %n = CirclePadV2(%n-1, constant_values <= 0)
+ * %n+1 = CircleMaxPool2D(%n, ...)
+ *
+ * C6) PadV2 could be an input of a 'reshaping op'. An example is as follows:
+ *
+ * %1 = CircleRelu
+ * %2 = CirclePadV2(%1, constant_values <= 0)
+ * %3 = CircleTranspose(%2) # reshaping op
+ * ... # more reshaping ops
+ * %n = CircleMaxPool2D(%n-1, ...)
+ *
+ * Why is this pass required?
+ *
+ * When a PyTorch model is converted into a Circle model, PadV2 is sometimes inserted with
+ * the following pattern:
+ *
+ * %1 = Circle.Conv2D(..., activation = Relu)
+ * %2 = Circle.Transpose(%1, perm=[0,3,1,2])
+ * %3 = Circle.PadV2(%2, constant_values = -3.4028234663852886e+38)
+ * %4 = Circle.Transpose(%3, perm=[0,2,3,1])
+ * %5 = Circle.MaxPool2D(%4, filter=[3,3], padding="VALID")
+ *
+ * The large negative padding constant of %3 caused problems when we quantized this model.
+ * So we need to convert the negative number to a value in a reasonable range for
+ * quantization, e.g., zero.
+ */
+namespace
+{
+
+struct Paddings
+{
+ struct Pad
+ {
+ int32_t front;
+ int32_t end;
+ };
+ /**
+ * @brief Store paddings position information.
+ * @details _padding_pos[k] stores Pad object at axis k
+ *
+ * @note Paddings must be for rank 4 tensor
+ */
+ std::vector<Pad> _padding_pos;
+
+ Paddings(luci::CircleConst *paddings)
+ {
+ assert(paddings->dtype() == loco::DataType::S32);
+ assert(paddings->rank() == 2);
+ assert(paddings->dim(1).value() == 2);
+ assert(paddings->size<loco::DataType::S32>() == paddings->rank() * 4);
+
+ for (uint32_t i = 0; i < paddings->dim(0).value(); i++)
+ {
+ Pad pad{.front = paddings->at<loco::DataType::S32>(i * 2),
+ .end = paddings->at<loco::DataType::S32>(i * 2 + 1)};
+ _padding_pos.emplace_back(pad);
+ }
+
+ assert(_padding_pos.size() == 4);
+ }
+
+ /**
+ * @brief Check if this padding area is covered by filter
+ *
+ * @note This is to check condition C4).
+ * _padding_pos should store values according to NHWC.
+ */
+ bool smaller_than(int32_t filter_h, int32_t filter_w)
+ {
+ auto &pad_H = _padding_pos.at(1);
+ auto &pad_W = _padding_pos.at(2);
+
+ return (pad_H.front < filter_h) && (pad_H.end < filter_h) && (pad_W.front < filter_w) &&
+ (pad_W.end < filter_w);
+ }
+
+ /**
+ * @brief Track how paddings change after CircleTranspose
+ * @details Consider the following graph,
+ *
+ * %1 = Circle.Input
+ * %2 = Circle.PadV2(%1,
+ * paddings=[[0, 0], [0, 0], [2, 3], [4, 5]],
+ * padding_value = -100)
+ * %3 = Circle.Transpose(%2, perm[0, 2, 3, 1])
+ *
+ * Output of %3 has padding constant value(-100) from %2 at position below:
+ *
+ * - axis | front | end
+ * ------|-------|-----
+ * 0 | 0 | 0
+ * 1 | 2 | 3
+ * 2 | 4 | 5
+ * 3 | 0 | 0
+ *
+ * This method keeps track of such change of padding position.
+ */
+ void apply(luci::CircleTranspose *transpose)
+ {
+ assert(transpose);
+ luci::CircleConst *perm = loco::must_cast<luci::CircleConst *>(transpose->perm());
+
+ std::vector<Pad> transposed_pos;
+ transposed_pos.resize(4);
+
+ for (uint32_t to = 0; to < 4; to++)
+ {
+ int32_t from = perm->at<loco::DataType::S32>(to);
+ transposed_pos.at(to) = _padding_pos.at(from);
+ }
+
+ _padding_pos = transposed_pos;
+ }
+};
+
+struct ReshapingNode
+{
+ /// @brief Check if node is 'reshaping op'
+ static bool check(loco::Node *node)
+ {
+ if (dynamic_cast<luci::CircleTranspose *>(node))
+ return true;
+ // add more 'reshaping op'
+
+ return false;
+ }
+
+ /// @brief Return the reshaping op's input
+ static loco::Node *input(loco::Node *node)
+ {
+ if (auto transpose = dynamic_cast<luci::CircleTranspose *>(node))
+ return transpose->a();
+ // add more 'reshaping op'
+
+ throw std::runtime_error("Not yet supported reshaping op");
+ }
+};
+
+/// @brief Get only successor node
+loco::Node *get_only_succ(loco::Node *parent)
+{
+ assert(parent);
+
+ auto successors = loco::succs(parent);
+ if (successors.size() != 1)
+ return nullptr;
+
+ return *successors.begin();
+}
+
+// Check condition C1) and C5)
+bool positive_or_zero(loco::Node *ifm)
+{
+ assert(ifm);
+
+ if (ReshapingNode::check(ifm))
+ return positive_or_zero(ReshapingNode::input(ifm));
+
+ // Since Relu.output[i] >= 0
+ if (dynamic_cast<luci::CircleRelu *>(ifm))
+ return true;
+ if (auto conv = dynamic_cast<luci::CircleConv2D *>(ifm))
+ {
+ if (conv->fusedActivationFunction() == luci::FusedActFunc::RELU)
+ return true;
+ // Add more FusedActFunc
+ }
+ // Add more ops of which output[i] >= 0
+
+ return false;
+}
+
+template <loco::DataType DT> bool has_all_positive_values(luci::CircleConst *node)
+{
+ // Only numeric datatype is allowed
+ static_assert(DT != loco::DataType::Unknown);
+ static_assert(DT != loco::DataType::STRING);
+
+ assert(node);
+
+ auto size = node->size<DT>();
+ for (decltype(size) t = 0; t < size; t++)
+ {
+ typename loco::DataTypeImpl<DT>::Type val = node->at<DT>(t);
+ if (val <= 0)
+ return false;
+ }
+
+ return true;
+}
+
+// To check condition C2)
+bool has_all_positive_values(luci::CircleConst *node)
+{
+ assert(node);
+
+ if (node->dtype() == loco::DataType::FLOAT32)
+ return has_all_positive_values<loco::DataType::FLOAT32>(node);
+ // Add more datatype
+
+ throw std::runtime_error("Not yet supported datatype");
+}
+
+bool used_by_maxpool_only(luci::CircleNode *node, Paddings &paddings)
+{
+ auto successor = get_only_succ(node);
+
+ // when successor is not only-succ
+ if (successor == nullptr)
+ return false;
+
+ if (auto maxpool = dynamic_cast<luci::CircleMaxPool2D *>(successor))
+ {
+ // Let's check condition C4)
+ return paddings.smaller_than(maxpool->filter()->h(), maxpool->filter()->w());
+ }
+
+ // Let's check condition C6)
+ if (auto transpose = dynamic_cast<luci::CircleTranspose *>(successor))
+ {
+ auto appropriate = [](luci::CircleTranspose *transpose) {
+ luci::CircleConst *perm = loco::must_cast<luci::CircleConst *>(transpose->perm());
+
+ // For Transpose to be an input for MaxPool2D
+ return (transpose->rank() == 4) && (perm && perm->dtype() == loco::DataType::S32) &&
+ (perm->size<loco::DataType::S32>() == 4);
+ };
+
+ if (not appropriate(transpose))
+ return false;
+
+ paddings.apply(transpose);
+ return used_by_maxpool_only(transpose, paddings);
+ }
+ // Support more 'reshaping op' later
+
+ return false;
+}
+
+// Check condition C3), C4) and C6)
+bool used_by_maxpool_only(luci::CirclePadV2 *pad_v2)
+{
+ // For PadV2 to be an input for MaxPool2D
+ if (pad_v2->rank() != 4)
+ return false;
+
+ Paddings paddings(loco::must_cast<luci::CircleConst *>(pad_v2->paddings()));
+
+ return used_by_maxpool_only(pad_v2, paddings);
+}
+
+loco::Node *build_pad_from(luci::CirclePadV2 *pad_v2)
+{
+ auto copy_shape = [](const luci::CircleNode *src, luci::CircleNode *dest) {
+ auto rank = src->rank();
+ dest->rank(rank);
+
+ for (decltype(rank) axis = 0; axis < rank; axis++)
+ dest->dim(axis) = src->dim(axis);
+ };
+
+ auto g = pad_v2->graph();
+
+ auto pad = g->nodes()->create<luci::CirclePad>();
+ {
+ pad->name(pad_v2->name() + "/pad");
+ luci::add_origin(pad, luci::get_origin(pad_v2));
+
+ pad->dtype(pad_v2->dtype());
+ copy_shape(pad_v2, pad);
+
+ pad->input(pad_v2->input());
+ pad->paddings(pad_v2->paddings());
+ }
+
+ return pad;
+}
+
+luci::CirclePadV2 *get_padv2(loco::Node *node)
+{
+ if (auto padv2 = dynamic_cast<luci::CirclePadV2 *>(node))
+ return padv2;
+
+ if (ReshapingNode::check(node))
+ return get_padv2(ReshapingNode::input(node));
+
+ return nullptr;
+}
+
+bool substitute_padv2_to_pad(luci::CircleMaxPool2D *maxp)
+{
+ // precondition
+ assert(maxp);
+ assert(maxp->value());
+
+ auto pad_v2 = get_padv2(maxp->value());
+
+ if (pad_v2 == nullptr)
+ return false;
+
+ assert(pad_v2->input());
+
+ auto paddings = loco::must_cast<luci::CircleConst *>(pad_v2->paddings());
+ auto constant_values = loco::must_cast<luci::CircleConst *>(pad_v2->constant_values());
+
+ (void)paddings;
+ assert(paddings);
+ assert(paddings->dtype() == loco::DataType::S32);
+ assert(constant_values);
+ assert(constant_values->dtype() == pad_v2->dtype());
+
+ if (not positive_or_zero(pad_v2->input()))
+ return false;
+
+ if (has_all_positive_values(constant_values))
+ return false;
+
+ if (not used_by_maxpool_only(pad_v2))
+ return false;
+
+ auto pad = build_pad_from(pad_v2);
+
+ replace(pad_v2).with(pad);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * Case 1) Basic case
+ *
+ * BEFORE
+ * [CircleRelu]
+ * |
+ * | [CircleConst] [CircleConst]
+ * | | |
+ * -------+----------------------
+ * |
+ * [CirclePadV2]
+ * |
+ * [CircleMaxPool2D]
+ * |
+ *
+ * AFTER
+ * [CircleRelu]
+ * |
+ * | [CircleConst] [CircleNode] [CircleConst]
+ * | | | | |
+ * -------+------- -------------+--------------+
+ * | |
+ * [CirclePad] [CirclePadV2]
+ * |
+ * [CircleMaxPool2D]
+ * |
+ *
+ * Case 2) During conversion from a PyTorch model into a Circle model,
+ * it is common that 'Reshaping ops', e.g., CircleTranspose,
+ * are inserted between operations to switch NCHW to NHWC and vice versa.
+ * This pass also needs to handle such a situation.
+ *
+ * BEFORE
+ * [CircleRelu]
+ * |
+ * | [CircleConst] [CircleConst]
+ * | | |
+ * -------+----------------------
+ * |
+ * [CircleTranspose]
+ * |
+ * [CirclePadV2]
+ * |
+ * [CircleTranspose]
+ * |
+ * [CircleMaxPool2D]
+ * |
+ *
+ * AFTER
+ * [CircleRelu]
+ * |
+ * | [CircleConst] [CircleNode] [CircleConst]
+ * | | | | |
+ * -------+------- -------------+--------------+
+ * | |
+ * [CircleTranspose] [CirclePadV2]
+ * |
+ * [CirclePad]
+ * |
+ * [CircleTranspose]
+ * |
+ * [CircleMaxPool2D]
+ * |
+ */
+bool SubstitutePadV2ToPadPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto circle_node = dynamic_cast<luci::CircleMaxPool2D *>(node))
+ {
+ if (substitute_padv2_to_pad(circle_node))
+ {
+ changed = true;
+ }
+ }
+ }
+ return changed;
+}
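+// Note: a minimal usage sketch, mirroring the unit tests (not part of this pass):
+//
+//   luci::CircleShapeInferencePass shape_inf;
+//   luci::SubstitutePadV2ToPadPass substitute;
+//   shape_inf.run(graph);                  // the tests run shape inference first
+//   bool changed = substitute.run(graph);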
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/SubstitutePadV2ToPadPass.h"
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using UIntList = std::initializer_list<uint32_t>;
+using IntList = std::initializer_list<int32_t>;
+
+// convert shape in UIntList to loco::TensorShape
+std::unique_ptr<loco::TensorShape> tensor_shape(const UIntList &values)
+{
+ auto shape = std::make_unique<loco::TensorShape>();
+ {
+ shape->rank(values.size());
+
+ uint32_t r = 0;
+ for (auto v : values)
+ shape->dim(r++).set(v);
+ }
+ return shape;
+}
+
+class TestGraph
+{
+public:
+ void init(const UIntList &input_shape, const UIntList &output_shape)
+ {
+ _input = _g.nodes()->create<luci::CircleInput>();
+ {
+ _input->name("input");
+ _input->dtype(loco::DataType::FLOAT32);
+ _input->shape(input_shape);
+
+ auto graph_input = _g.inputs()->create();
+ {
+ _input->index(graph_input->index());
+ graph_input->shape(tensor_shape(input_shape));
+ }
+ }
+
+ _output = _g.nodes()->create<luci::CircleOutput>();
+ {
+ _output->name("output");
+ _output->dtype(loco::DataType::FLOAT32);
+ _output->shape(output_shape);
+
+ auto graph_output = _g.outputs()->create();
+ {
+ _output->index(graph_output->index());
+ graph_output->shape(tensor_shape(output_shape));
+ }
+ }
+
+ // subclass should implement build_body()
+ auto graphlet_before_output = build_body(_input);
+
+ _output->from(graphlet_before_output);
+ }
+
+ // build luci::CircleConst for paddings
+ luci::CircleConst *paddings_const(const std::vector<int32_t> &plist)
+ {
+ assert(plist.size() == 8);
+
+ auto node = _g.nodes()->create<luci::CircleConst>();
+ {
+ node->dtype(loco::DataType::S32);
+ node->shape({4, 2});
+ node->size<loco::DataType::S32>(8);
+
+ for (int32_t t = 0; t < 8; t++)
+ node->at<loco::DataType::S32>(t) = plist.at(t);
+ }
+
+ return node;
+ }
+
+ // build luci::CircleConst for the padding constant value
+ luci::CircleConst *padding_val_const(float val)
+ {
+ auto node = _g.nodes()->create<luci::CircleConst>();
+ {
+ node->dtype(loco::DataType::FLOAT32);
+ node->shape({1});
+ node->size<loco::DataType::FLOAT32>(1);
+
+ node->at<loco::DataType::FLOAT32>(0) = val;
+ }
+
+ return node;
+ }
+
+ // build luci::CirclePadV2
+ luci::CirclePadV2 *padV2(loco::Node *input, const std::vector<int32_t> &paddings,
+ float padding_constant)
+ {
+ auto padv2 = _g.nodes()->create<luci::CirclePadV2>();
+ {
+ padv2->name("PadV2");
+ padv2->dtype(loco::DataType::FLOAT32);
+
+ padv2->input(input);
+ padv2->paddings(paddings_const(paddings));
+ padv2->constant_values(padding_val_const(padding_constant));
+ // No shape setting. ShapeInference should be run later
+ }
+ return padv2;
+ }
+
+ // build luci::CircleMaxPool2D
+ luci::CircleMaxPool2D *maxpool2d(loco::Node *input,
+ const std::pair<uint32_t, uint32_t> &kernel_HW)
+ {
+ auto mp = _g.nodes()->create<luci::CircleMaxPool2D>();
+ {
+ mp->value(input);
+ mp->fusedActivationFunction(luci::FusedActFunc::NONE);
+ mp->padding(luci::Padding::VALID);
+ mp->filter()->h(kernel_HW.first);
+ mp->filter()->w(kernel_HW.second);
+ mp->stride()->h(1);
+ mp->stride()->w(1);
+
+ mp->dtype(loco::DataType::FLOAT32);
+ // No shape setting. ShapeInference should be run later
+ }
+ return mp;
+ }
+
+ // build luci::CircleRelu
+ luci::CircleRelu *relu(loco::Node *input)
+ {
+ auto relu = _g.nodes()->create<luci::CircleRelu>();
+ {
+ relu->features(input);
+ relu->dtype(loco::DataType::FLOAT32);
+ // No shape setting. ShapeInference should be run later
+ }
+ return relu;
+ }
+
+ // build luci::CircleTranspose
+ luci::CircleTranspose *transpose(loco::Node *input, const std::vector<int32_t> &perm_v)
+ {
+ auto perm = _g.nodes()->create<luci::CircleConst>();
+ {
+ auto rank = static_cast<uint32_t>(perm_v.size());
+ perm->dtype(loco::DataType::S32);
+ perm->size<loco::DataType::S32>(rank);
+ perm->shape({rank});
+ for (decltype(rank) d = 0; d < rank; d++)
+ perm->at<loco::DataType::S32>(d) = perm_v.at(d);
+ }
+ auto transpose_node = _g.nodes()->create<luci::CircleTranspose>();
+ {
+ transpose_node->a(input);
+ transpose_node->perm(perm);
+ transpose_node->dtype(loco::DataType::S32);
+ // No shape setting. ShapeInference should be run later
+ }
+ return transpose_node;
+ }
+
+ loco::Graph *g() { return &_g; }
+ luci::CircleOutput *output() { return _output; }
+
+ virtual loco::Node *build_body(loco::Node *input) = 0;
+
+private:
+ loco::Graph _g;
+ luci::CircleInput *_input = nullptr;
+ luci::CircleOutput *_output = nullptr;
+};
+
+class SubstitutePadV2ToPadPassTest : public ::testing::Test
+{
+public:
+ SubstitutePadV2ToPadPassTest() = default;
+
+ bool run_pass(loco::Graph *g)
+ {
+ _shapeinf_pass.run(g);
+
+ return _pad_pass.run(g);
+ }
+
+protected:
+ luci::SubstitutePadV2ToPadPass _pad_pass;
+ luci::CircleShapeInferencePass _shapeinf_pass;
+};
+
+} // namespace
+
+/**
+ * Graph that is changed by SubstitutePadV2ToPadPass
+ *
+ * [CircleInput]
+ * |
+ * [Relu]
+ * |
+ * [CirclePadV2] pad.H.front = 1, pad.H.end = 1, pad.W.front = 1, pad.W.end = 1
+ * |
+ * [MaxPool2D] filter.H = 2, filter.W = 2
+ * |
+ * [CircleOutput]
+ */
+TEST_F(SubstitutePadV2ToPadPassTest, basic_case)
+{
+ struct Graph_basic : public TestGraph
+ {
+ Graph_basic()
+ {
+ UIntList input_shape = {1, 4, 4, 3};
+ UIntList output_shape = {1, 6, 6, 3};
+ init(input_shape, output_shape);
+ }
+
+ loco::Node *build_body(loco::Node *input) final
+ {
+ auto relu_node = relu(input);
+
+ IntList paddings = {0, 0, 1, 1, 1, 1, 0, 0};
+ auto padding_const = -10.0;
+ auto padV2_node = padV2(relu_node, paddings, padding_const);
+
+ return maxpool2d(padV2_node, {2, 2});
+ }
+ } graph;
+
+ auto result = run_pass(graph.g());
+ ASSERT_TRUE(result);
+
+ // Checking CircleMaxPool2D
+ auto maxpool = dynamic_cast<luci::CircleMaxPool2D *>(graph.output()->from());
+ ASSERT_TRUE(maxpool != nullptr);
+
+ // Checking CirclePad
+ auto pad = dynamic_cast<luci::CirclePad *>(maxpool->value());
+ ASSERT_TRUE(pad != nullptr);
+
+ // Checking CircleRelu
+ auto relu = dynamic_cast<luci::CircleRelu *>(pad->input());
+ ASSERT_TRUE(relu != nullptr);
+
+ auto input = dynamic_cast<luci::CircleInput *>(relu->features());
+ ASSERT_TRUE(input != nullptr);
+}
+
+/**
+ * Graph that is changed by SubstitutePadV2ToPadPass
+ *
+ * Transpose ops are inserted, e.g., to switch layout between NHWC and NCHW
+ *
+ * [CircleInput]
+ * |
+ * [Relu]
+ * | 1x4x4x3 (NHWC)
+ * [Transpose] perm=[0,3,1,2]
+ * | 1x3x4x4 (NCHW)
+ * [CirclePadV2] paddings=[0,0,0,0,1,1,1,1]
+ * | 1x3x6x6 (NCHW)
+ * [Transpose] perm=[0,2,3,1]
+ * | 1x6x6x3 (NHWC)
+ * [MaxPool2D] filter.H = 3, filter.W = 3
+ * | 1x4x4x3 (NHWC)
+ * [CircleOutput]
+ */
+TEST_F(SubstitutePadV2ToPadPassTest, reshaping_op_case)
+{
+ struct Graph_Reshaping_Op : public TestGraph
+ {
+ Graph_Reshaping_Op()
+ {
+ UIntList input_shape = {1, 4, 4, 3};
+ UIntList output_shape = {1, 4, 4, 3};
+ init(input_shape, output_shape);
+ }
+
+ loco::Node *build_body(loco::Node *input) final
+ {
+ auto relu_node = relu(input);
+
+ auto transpose1_node = transpose(relu_node, {0, 3, 1, 2});
+
+ IntList paddings = {0, 0, 0, 0, 1, 1, 1, 1};
+ auto padding_const = -10.0;
+ auto padV2_node = padV2(transpose1_node, paddings, padding_const);
+
+ auto transpose2_node = transpose(padV2_node, {0, 2, 3, 1});
+
+ return maxpool2d(transpose2_node, {3, 3});
+ }
+ } graph;
+
+ auto result = run_pass(graph.g());
+ ASSERT_TRUE(result);
+
+ // Checking CircleMaxPool2D
+ auto maxpool = dynamic_cast<luci::CircleMaxPool2D *>(graph.output()->from());
+ ASSERT_TRUE(maxpool != nullptr);
+
+ // Checking Transpose
+ auto transpose1 = dynamic_cast<luci::CircleTranspose *>(maxpool->value());
+ ASSERT_TRUE(transpose1 != nullptr);
+
+ // Checking CirclePad
+ auto pad = dynamic_cast<luci::CirclePad *>(transpose1->a());
+ ASSERT_TRUE(pad != nullptr);
+
+ // Checking Transpose
+ auto transpose2 = dynamic_cast<luci::CircleTranspose *>(pad->input());
+ ASSERT_TRUE(transpose2 != nullptr);
+
+ // Checking CircleRelu
+ auto relu = dynamic_cast<luci::CircleRelu *>(transpose2->a());
+ ASSERT_TRUE(relu != nullptr);
+
+ auto input = dynamic_cast<luci::CircleInput *>(relu->features());
+ ASSERT_TRUE(input != nullptr);
+}
+
+//
+// Negative Tests
+//
+
+/**
+ * Graph that is not changed by SubstitutePadV2ToPadPass
+ *
+ * [CircleInput]
+ * |
+ * [CirclePadV2]
+ * |
+ * [CircleOutput]
+ */
+TEST_F(SubstitutePadV2ToPadPassTest, no_relu_maxpool_NEG)
+{
+ struct Graph_No_MaxPool : public TestGraph
+ {
+ Graph_No_MaxPool()
+ {
+ UIntList input_shape = {1, 4, 4, 3};
+ UIntList output_shape = {1, 6, 8, 3};
+ init(input_shape, output_shape);
+ }
+
+ loco::Node *build_body(loco::Node *input) final
+ {
+ IntList paddings = {0, 0, 1, 1, 2, 2, 0, 0};
+ auto padding_const = -10.0;
+ return padV2(input, paddings, padding_const);
+ }
+ } graph;
+
+ auto result = run_pass(graph.g());
+
+ ASSERT_FALSE(result);
+}
+
+/**
+ * Graph that is not changed by SubstitutePadV2ToPadPass
+ *
+ * There is no CircleMaxPool2D.
+ *
+ * [CircleInput]
+ * |
+ * [CircleRelu]
+ * |
+ * [CirclePadV2]
+ * |
+ * [CircleOutput]
+ */
+TEST_F(SubstitutePadV2ToPadPassTest, no_maxpool_NEG)
+{
+ struct Graph_No_MaxPool : public TestGraph
+ {
+ Graph_No_MaxPool()
+ {
+ UIntList input_shape = {1, 4, 4, 3};
+ UIntList output_shape = {1, 6, 8, 3};
+ init(input_shape, output_shape);
+ }
+
+ loco::Node *build_body(loco::Node *input) final
+ {
+ auto relu_node = relu(input);
+
+ IntList paddings = {0, 0, 1, 1, 2, 2, 0, 0};
+ auto padding_const = -10.0;
+ return padV2(relu_node, paddings, padding_const);
+ }
+ } graph;
+
+ auto result = run_pass(graph.g());
+
+ ASSERT_FALSE(result);
+}
+
+/**
+ * Graph where PadV2 has non-negative constant value
+ *
+ * [CircleInput]
+ * |
+ * [Relu]
+ * |
+ * [CirclePadV2]
+ * |
+ * [MaxPool2D]
+ * |
+ * [CircleOutput]
+ */
+TEST_F(SubstitutePadV2ToPadPassTest, non_negative_NEG)
+{
+ struct NegGraph : public TestGraph
+ {
+ NegGraph()
+ {
+ UIntList input_shape = {1, 4, 4, 3};
+ UIntList output_shape = {1, 6, 6, 3};
+ init(input_shape, output_shape);
+ }
+
+ loco::Node *build_body(loco::Node *input) final
+ {
+ constexpr auto POSITIVE_CONST_VALUE = 0.1f;
+
+ auto relu_node = relu(input);
+
+ IntList paddings = {0, 0, 1, 1, 1, 1, 0, 0};
+ auto padV2_node = padV2(relu_node, paddings, POSITIVE_CONST_VALUE);
+
+ return maxpool2d(padV2_node, {2, 2});
+ }
+ } graph;
+
+ auto result = run_pass(graph.g());
+
+ ASSERT_FALSE(result);
+}
+
+/**
+ * Graph that has PadV2.paddings wider than MaxPool2D.Filter
+ *
+ * [CircleInput]
+ * |
+ * [CircleRelu]
+ * |
+ * [CirclePadV2] paddings=[0, 0, 3, 3, 1, 1, 0, 0]
+ * |
+ * [CircleMaxPool2D] Filter_H = 2, Filter_W = 2 (Filter_H < paddings for H)
+ * |
+ * [CircleOutput]
+ */
+TEST_F(SubstitutePadV2ToPadPassTest, wider_paddings_01_NEG)
+{
+ struct NegGraph : public TestGraph
+ {
+ NegGraph()
+ {
+ UIntList input_shape = {1, 4, 4, 3};
+ UIntList output_shape = {1, 9, 5, 3};
+ init(input_shape, output_shape);
+ }
+
+ loco::Node *build_body(loco::Node *input) final
+ {
+ auto relu_node = relu(input);
+
+ constexpr auto TOO_WIDE_H_FRONT = 3;
+ constexpr auto TOO_WIDE_H_END = 3;
+
+ IntList paddings = {0, 0, TOO_WIDE_H_FRONT, TOO_WIDE_H_END, 1, 1, 0, 0};
+ auto padding_const = -10.0;
+ auto padv2 = padV2(relu_node, paddings, padding_const);
+
+ return maxpool2d(padv2, {2, 2});
+ }
+ } graph;
+
+ auto result = run_pass(graph.g());
+
+ ASSERT_FALSE(result);
+}
+
+/**
+ * Graph that has PadV2.paddings wider than MaxPool2D.Filter
+ *
+ * Transpose ops are inserted, e.g., to switch layout between NHWC and NCHW
+ *
+ * [CircleInput]
+ * |
+ * [Relu]
+ * | 1x4x4x3 (NHWC)
+ * [Transpose] perm=[0,3,1,2]
+ * | 1x3x4x4 (NCHW)
+ * [CirclePadV2] paddings=[0,0,0,0,3,3,1,1]
+ * | 1x3x6x6 (NCHW)
+ * [Transpose] perm=[0,2,3,1]
+ * | 1x6x6x3 (NHWC)
+ * [MaxPool2D] filter.H = 2, filter.W = 2
+ * | 1x4x4x3
+ * [CircleOutput]
+ */
+TEST_F(SubstitutePadV2ToPadPassTest, wider_paddings_02_NEG)
+{
+ struct Graph_Reshaping_Op : public TestGraph
+ {
+ Graph_Reshaping_Op()
+ {
+ UIntList input_shape = {1, 4, 4, 3};
+ UIntList output_shape = {1, 9, 5, 3};
+ init(input_shape, output_shape);
+ }
+
+ loco::Node *build_body(loco::Node *input) final
+ {
+ auto relu_node = relu(input);
+
+ auto transpose1_node = transpose(relu_node, {0, 3, 1, 2});
+
+ constexpr auto TOO_WIDE_H_FRONT = 3;
+ constexpr auto TOO_WIDE_H_END = 3;
+
+ IntList paddings = {0, 0, 0, 0, TOO_WIDE_H_FRONT, TOO_WIDE_H_END, 1, 1};
+ auto padding_const = -10.0;
+ auto padV2_node = padV2(transpose1_node, paddings, padding_const);
+
+ auto transpose2_node = transpose(padV2_node, {0, 2, 3, 1});
+
+ return maxpool2d(transpose2_node, {3, 3});
+ }
+ } graph;
+
+ auto result = run_pass(graph.g());
+ ASSERT_FALSE(result);
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/SubstituteStridedSliceToReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+#include <bitset>
+#include <vector>
+
+/**
+ * @brief Convert strided_slice op to reshape op under certain conditions
+ * @details Convert strided_slice op if it meets all of the following conditions:
+ * For all i, 0 <= i < input.rank
+ * - begin[i] == 0
+ * - end[i] >= input.shape.dim[i]
+ * - strides[i] == 1
+ * For all k (0 <= k < input.rank) where kth bit of shrink_axis_mask == 1
+ * - input.shape.dim[k] == 1
+ *
+ * Example:
+ * input.shape = [1,1,2,3]
+ * strided_slice(input, begin=[0,0,0,0], end=[1,1,2,3], strides=[1,1,1,1],
+ * shrink_axis_mask=0011b) // k = 0, 1
+ *
+ * can be converted to
+ *
+ * reshape(input, [2,3])
+ */
+namespace
+{
+
+/**
+ * @brief Return newly-created CircleConst whose rank is 1
+ */
+luci::CircleConst *build_rank1_const(loco::Graph *graph, const std::vector<uint32_t> &values)
+{
+ auto const_node = graph->nodes()->create<luci::CircleConst>();
+ const_node->dtype(loco::DataType::S32);
+ const_node->size<loco::DataType::S32>(values.size());
+ const_node->shape_status(luci::ShapeStatus::VALID);
+ const_node->rank(1);
+ const_node->dim(0) = values.size();
+
+ for (size_t i = 0; i < values.size(); i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = values.at(i);
+ }
+
+ return const_node;
+}
+
+/**
+ * @brief Return newly-created CircleReshape node
+ */
+luci::CircleNode *build_reshape(loco::Graph *graph, const std::string &name,
+ const std::shared_ptr<luci::CircleNodeOrigin> &origin,
+ luci::CircleNode *input, const std::vector<uint32_t> &new_shape)
+{
+ auto reshape_node = graph->nodes()->create<luci::CircleReshape>();
+ reshape_node->tensor(input);
+ reshape_node->name(name);
+ luci::add_origin(reshape_node, origin);
+
+ auto new_shape_const = build_rank1_const(graph, new_shape);
+ {
+ new_shape_const->name(name + "/new_shape");
+ luci::add_origin(new_shape_const, origin);
+ }
+
+ reshape_node->shape(new_shape_const);
+
+ return reshape_node;
+}
+
+/**
+ * @brief Return value in position on CircleConst with int64 format.
+ */
+int64_t value_from_circle_const(const luci::CircleConst *node, uint32_t idx)
+{
+ assert(node->rank() == 1 && node->dim(0).value() > idx);
+ assert(node->dtype() == loco::DataType::S64 || node->dtype() == loco::DataType::S32);
+
+ if (node->dtype() == loco::DataType::S64)
+ return node->at<loco::DataType::S64>(idx);
+ return static_cast<int64_t>(node->at<loco::DataType::S32>(idx));
+}
+
+bool substitute_strided_slice_to_reshape(luci::CircleStridedSlice *ss_node)
+{
+ if (ss_node->shrink_axis_mask() == 0)
+ return false;
+
+ // TODO Consider cases with ellipsis_mask and new_axis_mask
+ // NOT YET SUPPORTED
+ if (ss_node->ellipsis_mask() != 0 or ss_node->new_axis_mask() != 0)
+ return false;
+
+ auto begin_const = dynamic_cast<luci::CircleConst *>(ss_node->begin());
+ auto strides_const = dynamic_cast<luci::CircleConst *>(ss_node->strides());
+ auto end_const = dynamic_cast<luci::CircleConst *>(ss_node->end());
+
+ if (not(begin_const && strides_const && end_const))
+ return false;
+
+ auto input_node = loco::must_cast<luci::CircleNode *>(ss_node->input());
+
+ // condition check
+ std::bitset<32> begin_mask(ss_node->begin_mask());
+ std::bitset<32> end_mask(ss_node->end_mask());
+ std::bitset<32> shrink_axis_mask(ss_node->shrink_axis_mask());
+
+ uint32_t input_rank = input_node->rank();
+ for (uint32_t i = 0; i < input_rank; i++)
+ {
+ if (!input_node->dim(i).known())
+ return false;
+
+ auto begin_dim = value_from_circle_const(begin_const, i);
+ if (begin_dim != 0 and begin_mask.test(i) == false)
+ return false;
+
+ // NOTE:
+ // In Tensorflow and TFLite, e.g., if input_shape = [2,3],
+ // strided_slice.end = [10,20] (larger value than actual dim)
+ // is treated as strided_slice.end = [2,3]
+ int64_t end_dim = value_from_circle_const(end_const, i);
+ if (end_dim < input_node->dim(i).value() and end_mask.test(i) == false)
+ return false;
+
+ int64_t strides_value = value_from_circle_const(strides_const, i);
+ if (strides_value != 1)
+ return false;
+
+ if (shrink_axis_mask.test(i) && input_node->dim(i).value() != 1)
+ return false;
+ }
+
+ // build shape for Reshape op
+ bool found = false;
+ std::vector<uint32_t> shrunk_shape;
+ for (uint32_t i = 0; i < input_rank; i++)
+ {
+ if (input_node->dim(i) == 1 and shrink_axis_mask.test(i))
+ found = true;
+ else
+ shrunk_shape.emplace_back(input_node->dim(i).value());
+ }
+
+ if (not found)
+ return false;
+
+ auto reshape_node = build_reshape(input_node->graph(), ss_node->name(), luci::get_origin(ss_node),
+ input_node, shrunk_shape);
+
+ replace(ss_node).with(reshape_node);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst] [CircleConst] [CircleConst]
+ * | | | |
+ * -------+------------------------------------
+ * |
+ * [CircleStridedSlice]
+ * |
+ * [CircleNode]
+ * |
+ * AFTER
+ * |
+ * [CircleConst] [CircleNode] [CircleConst] [CircleConst] [CircleConst]
+ * \ / \ | | |
+ * [CircleReshape] -------------------+----------------------
+ * | |
+ * [CircleNode] [CircleStridedSlice]
+ * |
+ */
+bool SubstituteStridedSliceToReshapePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto circle_node = dynamic_cast<luci::CircleStridedSlice *>(node))
+ {
+ if (substitute_strided_slice_to_reshape(circle_node))
+ {
+ changed = true;
+ }
+ }
+ }
+ return changed;
+}
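+// Note: a minimal usage sketch, mirroring the unit tests (not part of this pass):
+//
+//   luci::SubstituteStridedSliceToReshapePass pass;
+//   while (pass.run(graph))
+//     ; // e.g. StridedSlice on shape [1,1,5,1,9] with shrink_axis_mask 0b01001
+//       // is rewritten into Reshape to [1,5,9]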
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/SubstituteStridedSliceToReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+luci::CircleConst *build_rank1_const(loco::Graph *graph, const std::vector<uint32_t> &values)
+{
+ auto const_node = graph->nodes()->create<luci::CircleConst>();
+ const_node->dtype(loco::DataType::S32);
+ const_node->size<loco::DataType::S32>(values.size());
+ const_node->shape_status(luci::ShapeStatus::VALID);
+ const_node->rank(1);
+ const_node->dim(0) = values.size();
+
+ for (uint32_t i = 0; i < values.size(); i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = values.at(i);
+ }
+
+ return const_node;
+}
+
+class SubstituteStridedSliceToReshapeTest : public ::testing::Test
+{
+public:
+ SubstituteStridedSliceToReshapeTest() {}
+
+ void buildGraph(const std::initializer_list<uint32_t> input_shape,
+ const std::initializer_list<uint32_t> begin_vals,
+ const std::initializer_list<uint32_t> end_vals,
+ const std::initializer_list<uint32_t> strides_vals, int32_t begin_mask,
+ int32_t end_mask, int32_t ellipsis_mask, int32_t new_axis_mask,
+ int32_t shrink_axis_mask)
+ {
+ // Input node
+ input = g.nodes()->create<luci::CircleInput>();
+ {
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ input->shape_status(luci::ShapeStatus::VALID);
+ input->rank(input_shape.size());
+ input->shape(input_shape);
+ input->name("input");
+ }
+
+ // StridedSlice node
+ auto ss_node = g.nodes()->create<luci::CircleStridedSlice>();
+ {
+ auto *graph = &g;
+ auto build_attr = [&graph](const std::string &name,
+ const std::initializer_list<uint32_t> vals) {
+ auto node = build_rank1_const(graph, vals);
+ node->name(name);
+
+ return node;
+ };
+
+ ss_node->input(input);
+ auto begin = build_attr("begin", begin_vals);
+ auto end = build_attr("end", end_vals);
+ auto strides = build_attr("strides", strides_vals);
+
+ ss_node->begin(begin);
+ ss_node->end(end);
+ ss_node->strides(strides);
+
+ ss_node->begin_mask(begin_mask);
+ ss_node->end_mask(end_mask);
+ ss_node->ellipsis_mask(ellipsis_mask);
+ ss_node->new_axis_mask(new_axis_mask);
+ ss_node->shrink_axis_mask(shrink_axis_mask);
+ }
+
+ // Output node
+ output = g.nodes()->create<luci::CircleOutput>();
+ output->from(ss_node);
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+ output->name("output");
+ }
+
+ void assert_not_converted()
+ {
+ luci::SubstituteStridedSliceToReshapePass pass;
+ while (pass.run(&g))
+ ;
+
+ auto reshape_node = dynamic_cast<luci::CircleReshape *>(output->from());
+ ASSERT_TRUE(reshape_node == nullptr);
+
+ auto strided_slice_node = dynamic_cast<luci::CircleStridedSlice *>(output->from());
+ ASSERT_TRUE(strided_slice_node != nullptr);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(SubstituteStridedSliceToReshapePassTest, name)
+{
+ luci::SubstituteStridedSliceToReshapePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(SubstituteStridedSliceToReshapeTest, simple_case)
+{
+ buildGraph({1, 1, 5, 1, 9}, // input shape
+ {0, 0, 0, 0, 0}, // begin
+ {1, 1, 5, 1, 9}, // end
+ {1, 1, 1, 1, 1}, // strides
+ 0, // begin mask
+ 0, // end mask
+ 0, // ellipsis axis mask
+ 0, // new axis mask
+ 0b01001 // shrink axis mask, 0th and 3rd dim will be shrunk
+ );
+
+ luci::SubstituteStridedSliceToReshapePass pass;
+ while (pass.run(&g))
+ ;
+
+ auto reshape_node = dynamic_cast<luci::CircleReshape *>(output->from());
+ ASSERT_TRUE(reshape_node != nullptr);
+
+ auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
+ ASSERT_EQ(new_shape->rank(), 1);
+ ASSERT_EQ(new_shape->dim(0).value(), 3);
+ ASSERT_EQ(new_shape->at<loco::DataType::S32>(0), 1);
+ ASSERT_EQ(new_shape->at<loco::DataType::S32>(1), 5);
+ ASSERT_EQ(new_shape->at<loco::DataType::S32>(2), 9);
+}
+
+TEST_F(SubstituteStridedSliceToReshapeTest, with_begin_end_mask)
+{
+ buildGraph({5, 1, 9}, // input shape
+ {0, 0, 5}, // begin
+ {3, 1, 9}, // end
+ {1, 1, 1}, // strides
+ 0b100, // begin mask
+ 0b001, // end mask
+ 0, // ellipsis axis mask
+ 0, // new axis mask
+ 0b010 // shrink axis mask, 1st dim will be shrunk
+ );
+
+ luci::SubstituteStridedSliceToReshapePass pass;
+ while (pass.run(&g))
+ ;
+
+ auto reshape_node = dynamic_cast<luci::CircleReshape *>(output->from());
+ ASSERT_TRUE(reshape_node != nullptr);
+
+ auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
+ ASSERT_EQ(new_shape->rank(), 1);
+ ASSERT_EQ(new_shape->dim(0).value(), 2);
+ ASSERT_EQ(new_shape->at<loco::DataType::S32>(0), 5);
+ ASSERT_EQ(new_shape->at<loco::DataType::S32>(1), 9);
+}
+
+TEST_F(SubstituteStridedSliceToReshapeTest, with_large_end_mask)
+{
+ buildGraph({5, 1, 9}, // input shape
+ {0, 0, 0}, // begin
+ {100, 100, 100}, // large end
+ {1, 1, 1}, // strides
+ 0, // begin mask
+ 0, // end mask
+ 0, // ellipsis axis mask
+ 0, // new axis mask
+ 0b010 // shrink axis mask, 1st dim will be shrunk
+ );
+
+ luci::SubstituteStridedSliceToReshapePass pass;
+ while (pass.run(&g))
+ ;
+
+ auto reshape_node = dynamic_cast<luci::CircleReshape *>(output->from());
+ ASSERT_TRUE(reshape_node != nullptr);
+
+ auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
+ ASSERT_EQ(new_shape->rank(), 1);
+ ASSERT_EQ(new_shape->dim(0).value(), 2);
+ ASSERT_EQ(new_shape->at<loco::DataType::S32>(0), 5);
+ ASSERT_EQ(new_shape->at<loco::DataType::S32>(1), 9);
+}
+
+TEST_F(SubstituteStridedSliceToReshapeTest, not_matching_begin_index_NEG)
+{
+ buildGraph({1, 3, 5, 7}, // input shape
+ {0, 0, 2, 0}, // begin[2] does not start from 0
+ {1, 3, 5, 7}, // end
+ {1, 1, 1, 1}, // strides
+ 0, // begin mask
+ 0, // end mask
+ 0, // ellipsis axis mask
+ 0, // new axis mask
+ 0b0001 // shrink axis mask
+ );
+
+ assert_not_converted();
+ SUCCEED();
+}
+
+TEST_F(SubstituteStridedSliceToReshapeTest, not_matching_end_index_NEG)
+{
+ buildGraph({1, 3, 5, 7}, // input shape
+ {0, 0, 0, 0}, // begin
+ {1, 3, 3, 7}, // end[2] does not meet condition
+ {1, 1, 1, 1}, // strides
+ 0, // begin mask
+ 0, // end mask
+ 0, // ellipsis axis mask
+ 0, // new axis mask
+ 0b0001 // shrink axis mask
+ );
+
+ assert_not_converted();
+ SUCCEED();
+}
+
+TEST_F(SubstituteStridedSliceToReshapeTest, not_matching_strides_NEG)
+{
+ buildGraph({1, 3, 5, 7}, // input shape
+ {0, 0, 0, 0}, // begin
+ {1, 3, 5, 7}, // end
+ {1, 1, 2, 1}, // strides[2] does not meet condition
+ 0, // begin mask
+ 0, // end mask
+ 0, // ellipsis axis mask
+ 0, // new axis mask
+ 0b0001 // shrink axis mask
+ );
+
+ assert_not_converted();
+ SUCCEED();
+}
+
+TEST_F(SubstituteStridedSliceToReshapeTest, not_matching_shrink_axis_mask_NEG)
+{
+ buildGraph({1, 3, 5, 7}, // input shape
+ {0, 0, 0, 0}, // begin
+ {1, 3, 5, 7}, // end
+ {1, 1, 1, 1}, // strides
+ 0, // begin mask
+ 0, // end mask
+ 0, // ellipsis axis mask
+ 0, // new axis mask
+ 0b0101 // shrink axis mask[1] does not meet condition
+ );
+
+ assert_not_converted();
+ SUCCEED();
+}
loco::Node *mini_input = nullptr;
// There are two ways Miminum takes inputs.
- // 1. Miminum(x = CircleNode, y = CircleMinimum)
- // 2. Miminum(x = CircleMinimum, y = CircleNode)
+ // 1. Minimum(x = CircleNode, y = CircleConst)
+ // 2. Minimum(x = CircleConst, y = CircleNode)
if (not luci::fill(&mini_const, &mini_input).with_commutative_args_of(mini))
return false;
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/TransformMinReluToRelu6Pass.h"
+
+#include "helpers/NodeFiller.h"
+#include "helpers/TypeMapper.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+
+template <loco::DataType DT>
+bool is_scalar_with_value(luci::CircleConst *node, typename loco::DataTypeImpl<DT>::Type val)
+{
+ if (node->dtype() != DT)
+ return false;
+ if (node->rank() != 0)
+ return false;
+ if (node->size<DT>() != 1)
+ return false;
+ if (node->at<DT>(0) != static_cast<typename loco::DataTypeImpl<DT>::Type>(val))
+ return false;
+
+ return true;
+}
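+// For example (illustrative): is_scalar_with_value<loco::DataType::FLOAT32>(c, 6.0f)
+// is true only when 'c' is a rank-0 FLOAT32 constant holding exactly the value 6.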
+
+/**
+ * BEFORE
+ * [CircleNode]
+ * |
+ * [CircleMinimum]
+ * |
+ * [CircleRelu]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode]
+ * |
+ * [CircleRelu6]
+ * |
+ * [CircleNode]
+ *
+ * NOTE Only relu(min(input, 6)) pattern will be transformed.
+ */
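+// Rationale (informal): relu(min(x, 6)) == max(min(x, 6), 0) == clamp(x, 0, 6),
+// which is exactly what Relu6 computes, so the pattern can be replaced when the
+// Minimum constant is the scalar 6.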
+template <loco::DataType DT> bool transform_min_relu_pattern(luci::CircleRelu *relu)
+{
+ if (not relu)
+ return false;
+
+ if (relu->dtype() != DT)
+ return false;
+
+ auto *mini = dynamic_cast<luci::CircleMinimum *>(relu->features());
+ if (not mini)
+ return false;
+
+ luci::CircleConst *mini_const = nullptr;
+ loco::Node *mini_input = nullptr;
+
+ // There are two ways Minimum takes inputs.
+ // 1. Minimum(x = CircleNode, y = CircleConst)
+ // 2. Minimum(x = CircleConst, y = CircleNode)
+ if (not luci::fill(&mini_const, &mini_input).with_commutative_args_of(mini))
+ return false;
+
+ // Minimum constant should be a scalar whose value is 6.
+ if (not is_scalar_with_value<DT>(mini_const,
+ static_cast<typename loco::DataTypeImpl<DT>::Type>(6)))
+ return false;
+
+ auto name = relu->name();
+ assert(name.length() > 0);
+
+ // Create Relu6 op
+ auto relu6 = mini->graph()->nodes()->create<luci::CircleRelu6>();
+ relu6->features(mini_input);
+ relu6->name(name + "/Relu6");
+ luci::add_origin(relu6, luci::composite_origin({luci::get_origin(relu), luci::get_origin(mini)}));
+
+ replace(relu).with(relu6);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool TransformMinReluToRelu6Pass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto relu = dynamic_cast<luci::CircleRelu *>(node))
+ {
+ if (transform_min_relu_pattern<loco::DataType::FLOAT32>(relu))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/TransformMinReluToRelu6Pass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * Minimum-Relu pattern graph
+ *
+ * [CircleInput] [CircleConst]
+ * \ /
+ * [CircleMinimum]
+ * |
+ * [CircleRelu]
+ * |
+ * [CircleOutput]
+ */
+struct MinReluGraph
+{
+ loco::Graph _g;
+ luci::CircleInput *_input = nullptr;
+ luci::CircleMinimum *_mini = nullptr;
+ luci::CircleConst *_mini_const = nullptr;
+ luci::CircleRelu *_relu = nullptr;
+ luci::CircleOutput *_output = nullptr;
+};
+
+class TransformMinReluToRelu6PassTest : public ::testing::Test
+{
+protected:
+ virtual void SetUp()
+ {
+ const int N = 1;
+ const int H = 4;
+ const int W = 4;
+ const int C = 3;
+
+ // graph input and output
+ auto graph_input = _min_relu_g._g.inputs()->create();
+ auto graph_output = _min_relu_g._g.outputs()->create();
+
+ // CircleInput
+ _min_relu_g._input = _min_relu_g._g.nodes()->create<luci::CircleInput>();
+ _min_relu_g._input->index(graph_input->index());
+ _min_relu_g._input->shape({N, H, W, C});
+ _min_relu_g._input->dtype(loco::DataType::FLOAT32);
+ _min_relu_g._input->name("input");
+
+ // CircleConst
+ _min_relu_g._mini_const = _min_relu_g._g.nodes()->create<luci::CircleConst>();
+ _min_relu_g._mini_const->shape({}); // scalar
+ _min_relu_g._mini_const->dtype(loco::DataType::FLOAT32);
+ _min_relu_g._mini_const->size<loco::DataType::FLOAT32>(1);
+ _min_relu_g._mini_const->at<loco::DataType::FLOAT32>(0) = 6.;
+ _min_relu_g._mini_const->name("mini_const");
+
+ // CircleMinimum
+ _min_relu_g._mini = _min_relu_g._g.nodes()->create<luci::CircleMinimum>();
+ _min_relu_g._mini->x(_min_relu_g._input);
+ _min_relu_g._mini->y(_min_relu_g._mini_const);
+ _min_relu_g._mini->shape({N, H, W, C});
+ _min_relu_g._mini->dtype(loco::DataType::FLOAT32);
+ _min_relu_g._mini->name("mini");
+
+ // CircleRelu
+ _min_relu_g._relu = _min_relu_g._g.nodes()->create<luci::CircleRelu>();
+ _min_relu_g._relu->features(_min_relu_g._mini);
+ _min_relu_g._relu->shape({N, H, W, C});
+ _min_relu_g._relu->dtype(loco::DataType::FLOAT32);
+ _min_relu_g._relu->name("relu");
+
+ // CircleOutput
+ _min_relu_g._output = _min_relu_g._g.nodes()->create<luci::CircleOutput>();
+ _min_relu_g._output->index(graph_output->index());
+ _min_relu_g._output->from(_min_relu_g._relu);
+ _min_relu_g._output->shape({N, H, W, C});
+ _min_relu_g._output->dtype(loco::DataType::FLOAT32);
+ _min_relu_g._output->name("output");
+ }
+
+protected:
+ luci::TransformMinReluToRelu6Pass _pass;
+ MinReluGraph _min_relu_g;
+};
+
+} // namespace
+
+TEST_F(TransformMinReluToRelu6PassTest, name)
+{
+ auto const name = _pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+/**
+ * Optimized graph looks like below.
+ *
+ * [CircleInput]
+ * |
+ * [CircleRelu6]
+ * |
+ * [CircleOutput]
+ */
+TEST_F(TransformMinReluToRelu6PassTest, simple_test)
+{
+ auto ret = _pass.run(&_min_relu_g._g);
+ EXPECT_TRUE(ret);
+
+ auto relu6 = dynamic_cast<luci::CircleRelu6 *>(_min_relu_g._output->from());
+ EXPECT_NE(nullptr, relu6);
+
+ auto input = dynamic_cast<luci::CircleInput *>(relu6->features());
+ EXPECT_NE(nullptr, input);
+}
+
+TEST_F(TransformMinReluToRelu6PassTest, wrong_condition_NEG)
+{
+ _min_relu_g._mini_const->at<loco::DataType::FLOAT32>(0) = 2.;
+
+ auto ret = _pass.run(&_min_relu_g._g);
+
+ EXPECT_FALSE(ret);
+}
RETURN_FALSE_UNLESS(is_lwq(node))
RETURN_FALSE_UNLESS(is_lwq(node->input()))
RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
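+ // Bias is optional (it can be CircleOutputExclude)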
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
return true;
}
RETURN_FALSE_UNLESS(is_lwq(node))
RETURN_FALSE_UNLESS(is_lwq(node->input()))
RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
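+ // Bias is optional (it can be CircleOutputExclude)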
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
return true;
}
return true;
}
+ bool visit(const luci::CirclePack *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ for (uint32_t i = 0; i < node->values_count(); i++)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
+ }
+ return true;
+ }
+
bool visit(const luci::CirclePad *node)
{
RETURN_FALSE_UNLESS(is_lwq(node))
return true;
}
+ bool visit(const luci::CirclePadV2 *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
+ return true;
+ }
+
+ bool visit(const luci::CircleMirrorPad *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ return true;
+ }
+
bool visit(const luci::CirclePRelu *node)
{
RETURN_FALSE_UNLESS(is_lwq(node))
RETURN_FALSE_UNLESS(is_lwq(node))
RETURN_FALSE_UNLESS(is_lwq(node->input()))
RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
- RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ // Bias is optional (it can be CircleOutputExclude)
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
return true;
}
return true;
}
+ bool visit(const luci::CircleLocalResponseNormalization *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
bool visit(const luci::CircleMean *node)
{
RETURN_FALSE_UNLESS(is_lwq(node));
bool visit(const luci::CircleReshape *node)
{
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->tensor()));
+ auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
+ bool input_quantized = input->quantparam() != nullptr;
+ bool node_quantized = node->quantparam() != nullptr;
+ RETURN_FALSE_UNLESS(input_quantized == node_quantized);
+ RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
+ RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
return true;
}
return true;
}
+ bool visit(const luci::CircleResizeNearestNeighbor *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleUnpack *node)
+ {
+ // node's output is the input of CircleUnpackOut, thus not quantized
+ RETURN_FALSE_UNLESS(is_lwq(node->value()));
+ return true;
+ }
+
+ bool visit(const luci::CircleUnpackOut *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ return true;
+ }
+
+ bool visit(const luci::CircleCast *node)
+ {
+ auto input = loco::must_cast<const luci::CircleNode *>(node->x());
+ bool input_quantized = input->quantparam() != nullptr;
+ bool node_quantized = node->quantparam() != nullptr;
+ RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
+ RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
+ return true;
+ }
+
// TODO: Implement more Ops
bool visit(const luci::CircleNode *) { return true; }
RETURN_FALSE_UNLESS(is_lwq(node))
RETURN_FALSE_UNLESS(is_lwq(node->input()))
RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
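+ // Bias is optional (it can be CircleOutputExclude)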
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
return true;
}
RETURN_FALSE_UNLESS(is_lwq(node))
RETURN_FALSE_UNLESS(is_lwq(node->input()))
RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
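+ // Bias is optional (it can be CircleOutputExclude)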
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
return true;
}
return true;
}
+ bool visit(const luci::CirclePack *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ for (uint32_t i = 0; i < node->values_count(); i++)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
+ }
+ return true;
+ }
+
bool visit(const luci::CirclePad *node)
{
RETURN_FALSE_UNLESS(is_lwq(node))
return true;
}
+ bool visit(const luci::CirclePadV2 *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq(node->constant_values()))
+ return true;
+ }
+
+ bool visit(const luci::CircleMirrorPad *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ return true;
+ }
+
bool visit(const luci::CirclePRelu *node)
{
RETURN_FALSE_UNLESS(is_lwq(node))
RETURN_FALSE_UNLESS(is_lwq(node))
RETURN_FALSE_UNLESS(is_lwq(node->input()))
RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
- RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
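+ // Bias is optional (it can be CircleOutputExclude)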
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
return true;
}
return true;
}
+ bool visit(const luci::CircleLocalResponseNormalization *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
bool visit(const luci::CircleMean *node)
{
RETURN_FALSE_UNLESS(is_lwq(node))
bool visit(const luci::CircleReshape *node)
{
- RETURN_FALSE_UNLESS(is_lwq(node))
- RETURN_FALSE_UNLESS(is_lwq(node->tensor()));
+ auto input = loco::must_cast<const luci::CircleNode *>(node->tensor());
+ bool input_quantized = input->quantparam() != nullptr;
+ bool node_quantized = node->quantparam() != nullptr;
+ RETURN_FALSE_UNLESS(input_quantized == node_quantized);
+ RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node))
+ RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
return true;
}
return true;
}
+ bool visit(const luci::CircleResizeNearestNeighbor *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleUnpack *node)
+ {
+ // node's output is the input of CircleUnpackOut, thus not quantized
+ RETURN_FALSE_UNLESS(is_lwq(node->value()));
+ return true;
+ }
+
+ bool visit(const luci::CircleUnpackOut *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ return true;
+ }
+
+ bool visit(const luci::CircleCast *node)
+ {
+ auto input = loco::must_cast<const luci::CircleNode *>(node->x());
+ bool input_quantized = input->quantparam() != nullptr;
+ bool node_quantized = node->quantparam() != nullptr;
+ RETURN_FALSE_UNLESS(not input_quantized or is_lwq(input));
+ RETURN_FALSE_UNLESS(not node_quantized or is_lwq(node));
+ return true;
+ }
+
// TODO: Implement more Ops
bool visit(const luci::CircleNode *) { return true; }
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
+#include <cmath>
+
using Type = loco::DataType;
// This macro is undef at the end of the file
return true;
}
+ bool visit(const luci::CirclePack *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ for (uint32_t i = 0; i < node->values_count(); i++)
+ {
+ RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16))
+ }
+ return true;
+ }
+
bool visit(const luci::CirclePad *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
return true;
}
+ bool visit(const luci::CirclePadV2 *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
+ RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleMirrorPad *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
+ return true;
+ }
+
bool visit(const luci::CirclePRelu *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
RETURN_FALSE_UNLESS(has_type(node->weights(), Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
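+ // Bias is optional (it can be CircleOutputExclude)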
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Type::S64))
return true;
}
return true;
}
+ bool visit(const luci::CircleLocalResponseNormalization *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ return true;
+ }
+
bool visit(const luci::CircleMean *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
bool visit(const luci::CircleReshape *node)
{
- RETURN_FALSE_UNLESS(has_type(node, Type::S16))
- RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::S16))
+ if (node->quantparam())
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::S16))
+ }
+ else
+ {
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
+ }
luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
if (shape != nullptr)
RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
return true;
}
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
RETURN_FALSE_UNLESS(has_type(node->logits(), Type::S16))
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32767.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
return true;
}
bool visit(const luci::CircleSplitOut *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+
+ // SplitOut has the same qparam as the input of Split
+ auto split = loco::must_cast<luci::CircleSplit *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(split->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
return true;
}
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+
+ auto input = loco::must_cast<luci::CircleNode *>(node->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
return true;
}
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 32768.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
return true;
}
{
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+
+ // This checks that the scale value is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
return true;
}
RETURN_FALSE_UNLESS(has_type(node, Type::S16))
RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
+
+ // This checks that the scale value is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
return true;
}
return true;
}
+ bool visit(const luci::CircleResizeNearestNeighbor *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleUnpack *node)
+ {
+ // node's output is the input of CircleUnpackOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleUnpackOut *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+
+ // UnpackOut has the same qparam as the input of Unpack
+ auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(unpack->value());
+ RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+ }
+
+ bool visit(const luci::CircleCast *node)
+ {
+ auto *input = loco::must_cast<luci::CircleNode *>(node->x());
+ RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
+
+ bool input_quantized = input->quantparam() != nullptr;
+ if (input_quantized)
+ RETURN_FALSE_UNLESS(has_type(input, Type::S16))
+
+ RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
+
+ bool node_quantized = node->quantparam() != nullptr;
+ if (node_quantized)
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ return true;
+ }
+
// TODO: Implement more Ops
bool visit(const luci::CircleNode *) { return true; }
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
+#include <cmath>
+
using Type = loco::DataType;
// This macro is undef at the end of the file
return true;
}
+ bool visit(const luci::CirclePack *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ for (uint32_t i = 0; i < node->values_count(); i++)
+ {
+ RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
+ }
+ return true;
+ }
+
bool visit(const luci::CirclePad *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
return true;
}
+ bool visit(const luci::CirclePadV2 *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
+ RETURN_FALSE_UNLESS(has_type(node->constant_values(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleMirrorPad *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
+ return true;
+ }
+
bool visit(const luci::CirclePRelu *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
RETURN_FALSE_UNLESS(has_type(node->weights(), Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
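+ // Bias is optional (it can be CircleOutputExclude)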
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Type::S32))
return true;
}
return true;
}
+ bool visit(const luci::CircleLocalResponseNormalization *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
bool visit(const luci::CircleMean *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
bool visit(const luci::CircleReshape *node)
{
- RETURN_FALSE_UNLESS(has_type(node, Type::U8))
- RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::U8))
+ if (node->quantparam())
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::U8))
+ }
+ else
+ {
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), node->dtype()))
+ }
luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
if (shape != nullptr)
RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 256.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
return true;
}
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
RETURN_FALSE_UNLESS(has_type(node->logits(), Type::U8))
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 1.0f / 255.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 0);
return true;
}
bool visit(const luci::CircleSplitOut *node)
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+
+ // SplitOut has the same qparam as the input of Split
+ auto split = loco::must_cast<luci::CircleSplit *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(split->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
return true;
}
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+
+ auto input = loco::must_cast<luci::CircleNode *>(node->input());
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
return true;
}
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == 2.0f / 256.0f);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == 128);
return true;
}
{
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+
+ // This checks that the scale value is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
return true;
}
RETURN_FALSE_UNLESS(has_type(node, Type::U8))
RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+
+ // This checks that the scale value is an integer
+ RETURN_FALSE_UNLESS(node->quantparam());
+ RETURN_FALSE_UNLESS(std::roundf(node->quantparam()->scale[0]) == node->quantparam()->scale[0]);
return true;
}
return true;
}
+ bool visit(const luci::CircleResizeNearestNeighbor *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleUnpack *node)
+ {
+ // node's output is the input of CircleUnpackOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleUnpackOut *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+
+ // UnpackOut has the same qparam as the input of Unpack
+ auto unpack = loco::must_cast<luci::CircleUnpack *>(node->input());
+ auto input = loco::must_cast<luci::CircleNode *>(unpack->value());
+ RETURN_FALSE_UNLESS(node->quantparam() && input->quantparam());
+ RETURN_FALSE_UNLESS(node->quantparam()->scale[0] == input->quantparam()->scale[0]);
+ RETURN_FALSE_UNLESS(node->quantparam()->zerop[0] == input->quantparam()->zerop[0]);
+ return true;
+ }
+
+ bool visit(const luci::CircleCast *node)
+ {
+ auto *input = loco::must_cast<luci::CircleNode *>(node->x());
+ bool input_quantized = input->quantparam() != nullptr;
+ if (input_quantized)
+ {
+ RETURN_FALSE_UNLESS(has_type(input, node->in_data_type()))
+ RETURN_FALSE_UNLESS(has_type(input, Type::U8))
+ }
+
+ bool node_quantized = node->quantparam() != nullptr;
+ if (node_quantized)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, node->out_data_type()))
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ }
+ return true;
+ }
+
// TODO: Implement more Ops
bool visit(const luci::CircleNode *) { return true; }
QuantizationGranularity str_to_granularity(const std::string &);
-template <typename T> std::vector<T> csv_to_vector(const std::string &str)
-{
- std::vector<T> ret;
- std::istringstream is(str);
- for (T i; is >> i;)
- {
- assert(i != ',');
- ret.push_back(i);
- if (is.peek() == ',')
- is.ignore();
- }
- return ret;
-}
-
} // namespace luci
#endif // __LUCI_PASS_HELPERS_STRINGS_H__
EXPECT_THROW(luci::str_to_granularity("foo"), std::runtime_error);
}
-
-TEST(StringsTest, csv_to_vector_int32)
-{
- auto ret = luci::csv_to_vector<int32_t>("1,2,3");
- ASSERT_EQ(3, ret.size());
- ASSERT_EQ(1, ret.at(0));
- ASSERT_EQ(3, ret.at(2));
-}
target_link_libraries(luci_profile PUBLIC luci_lang)
install(TARGETS luci_profile DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
if(NOT ENABLE_TEST)
return()
std::shared_ptr<CircleNodeOrigin>
composite_origin(const std::initializer_list<std::shared_ptr<CircleNodeOrigin>> origins)
{
- return std::make_shared<CompositeOrigin>(origins);
+ auto origin = std::make_shared<CompositeOrigin>(origins);
+
+  // For empty sources, there is no need to create an origin
+ if (origin->sources().empty())
+ return nullptr;
+
+ return origin;
}
std::shared_ptr<CircleNodeOrigin>
composite_origin(const std::vector<std::shared_ptr<CircleNodeOrigin>> &origins)
{
- return std::make_shared<CompositeOrigin>(origins);
+ auto origin = std::make_shared<CompositeOrigin>(origins);
+
+  // For empty sources, there is no need to create an origin
+ if (origin->sources().empty())
+ return nullptr;
+
+ return origin;
}
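+
+// NOTE Callers such as add_origin() treat the nullptr return as "nothing to record".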
} // namespace luci
bool has_origin(const luci::CircleNode *circle_node)
{
- return circle_node->annot<CircleNodeOriginAnnotation>() != nullptr;
+ if (circle_node->annot<CircleNodeOriginAnnotation>() == nullptr)
+ return false;
+
+ assert(!circle_node->annot<CircleNodeOriginAnnotation>()->origin()->sources().empty());
+
+ return true;
}
/**
*/
void add_origin(luci::CircleNode *circle_node, const std::shared_ptr<CircleNodeOrigin> origin)
{
+ // Nothing to add
+ if (origin == nullptr)
+ return;
+
auto new_origin = composite_origin({get_origin(circle_node), origin});
circle_node->annot<CircleNodeOriginAnnotation>(nullptr);
circle_node->annot(std::make_unique<CircleNodeOriginAnnotation>(new_origin));
{
ASSERT_ANY_THROW(luci::composite_origin({}));
}
+
+TEST(LuciCircleNodeOrigin, add_null_origin_NEG)
+{
+ auto g = loco::make_graph();
+ auto add = g->nodes()->create<luci::CircleAdd>();
+
+ ASSERT_FALSE(has_origin(add));
+
+ add_origin(add, nullptr);
+
+ ASSERT_FALSE(has_origin(add));
+}
+
+TEST(LuciCircleNodeOrigin, add_empty_origin_NEG)
+{
+ auto g = loco::make_graph();
+ auto add = g->nodes()->create<luci::CircleAdd>();
+
+ ASSERT_FALSE(has_origin(add));
+
+ add_origin(add, luci::composite_origin({nullptr, nullptr}));
+
+ ASSERT_FALSE(has_origin(add));
+}
require("foder")
+require("pepper-csv2vec")
require("loco")
require("locop")
require("logo")
target_link_libraries(luci_service PUBLIC mio_circle)
target_link_libraries(luci_service PUBLIC logo_core)
target_link_libraries(luci_service PRIVATE luci_log)
+target_link_libraries(luci_service PRIVATE luci_logex)
target_link_libraries(luci_service PRIVATE nncc_common)
target_link_libraries(luci_service PRIVATE oops)
install(TARGETS luci_service DESTINATION lib)
+install(DIRECTORY include/ DESTINATION include
+ FILES_MATCHING PATTERN "*.h")
if(NOT ENABLE_TEST)
return()
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SVC_CHANGE_OUTPUTS_H__
+#define __LUCI_SVC_CHANGE_OUTPUTS_H__
+
+#include <loco/IR/Graph.h>
+
+#include <string>
+#include <vector>
+
+namespace luci
+{
+
+/**
+ * @brief Change graph outputs to the nodes with the given names.
+ *
+ * @note The number of names must match the existing number of outputs,
+ *       and every name must exist in the graph.
+ *       Throws an exception otherwise.
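+ *
+ * A minimal usage sketch (the graph pointer and the node name "conv2d_out"
+ * below are illustrative assumptions):
+ * @code
+ *   std::vector<std::string> names{"conv2d_out"};
+ *   luci::change_outputs(graph, names);
+ * @endcode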
+ */
+void change_outputs(loco::Graph *, const std::vector<std::string> &);
+
+} // namespace luci
+
+#endif // __LUCI_SVC_CHANGE_OUTPUTS_H__
// loco::TensorShape visit(const luci::CirclePadV2 *node) final;
// loco::TensorShape visit(const luci::CirclePow *node) final;
// loco::TensorShape visit(const luci::CirclePRelu *node) final;
+ // loco::TensorShape visit(const luci::CircleQuantize *node) final;
// loco::TensorShape visit(const luci::CircleRange *node) final;
// loco::TensorShape visit(const luci::CircleRank *node) final;
// loco::TensorShape visit(const luci::CircleReduceAny *node) final;
// loco::DataType visit(const luci::CircleRank *node) final;
// loco::DataType visit(const luci::CircleMul *node) final;
// loco::DataType visit(const luci::CircleOneHot *node) final;
+ // loco::DataType visit(const luci::CircleQuantize *node) final;
// loco::DataType visit(const luci::CircleReduceAny *node) final;
// loco::DataType visit(const luci::CircleReduceMax *node) final;
// loco::DataType visit(const luci::CircleReduceMin *node) final;
*/
bool validate_unique_name(luci::Module *);
+bool validate(luci::Module *);
+
} // namespace luci
#endif // __LUCI_SERVICE_VALIDATE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/ChangeOutputs.h"
+
+#include <luci/IR/CircleNode.h>
+
+#include <loco/IR/Graph.h>
+
+#include <oops/UserExn.h>
+
+#include <cassert>
+#include <iostream>
+#include <map>
+
+namespace
+{
+
+luci::CircleNode *find_by_name(loco::Graph *g, const std::string &name)
+{
+ for (auto node : loco::all_nodes(g))
+ {
+ auto cnode = loco::must_cast<luci::CircleNode *>(node);
+ if (cnode->name() == name)
+ return cnode;
+ }
+ return nullptr;
+}
+
+} // namespace
+
+namespace luci
+{
+
+void change_outputs(loco::Graph *graph, const std::vector<std::string> &new_outputs)
+{
+ if (new_outputs.size() != graph->outputs()->size())
+ {
+ throw oops::UserExn("Change outputs failed: number of outputs should be ",
+ graph->outputs()->size());
+ }
+
+ std::map<std::string, luci::CircleNode *> named_nodes;
+
+ for (auto &node_name : new_outputs)
+ {
+ auto node = find_by_name(graph, node_name);
+ if (node == nullptr)
+ {
+ throw oops::UserExn("Change outputs failed: node not found: ", node_name);
+ }
+ named_nodes[node_name] = node;
+ }
+ // just to be sure
+ assert(graph->outputs()->size() == named_nodes.size());
+
+ for (uint32_t out = 0; out < graph->outputs()->size(); ++out)
+ {
+ auto output = luci::output_node(graph, out); // output is CircleOutput
+ assert(output != nullptr);
+
+ auto node_name = new_outputs.at(out);
+ auto node = named_nodes[node_name];
+ assert(node != nullptr);
+
+ output->from(node);
+
+    // update GraphOutput shape and dtype to match the node
+ auto graph_out = graph->outputs()->at(out);
+ auto output_shape = std::make_unique<loco::TensorShape>();
+
+ output_shape->rank(node->rank());
+ for (uint32_t r = 0; r < node->rank(); ++r)
+ {
+ if (node->dim(r).known())
+ output_shape->dim(r).set(node->dim(r).value());
+ else
+ output_shape->dim(r).unset();
+ }
+ graph_out->shape(std::move(output_shape));
+ graph_out->dtype(node->dtype());
+ }
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/ChangeOutputs.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <luci/IR/Nodes/CircleSqrt.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class Sqrt2xGraphlet
+{
+public:
+ Sqrt2xGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 input_shape)
+ {
+ _sqrt1 = g->nodes()->create<luci::CircleSqrt>();
+ _sqrt1->dtype(loco::DataType::S32);
+ _sqrt1->name("sqrt1");
+
+ _sqrt2 = g->nodes()->create<luci::CircleSqrt>();
+ _sqrt2->dtype(loco::DataType::S32);
+ _sqrt2->name("sqrt2");
+ }
+
+public:
+ luci::CircleSqrt *sqrt1(void) const { return _sqrt1; }
+ luci::CircleSqrt *sqrt2(void) const { return _sqrt2; }
+
+protected:
+ luci::CircleSqrt *_sqrt1 = nullptr;
+ luci::CircleSqrt *_sqrt2 = nullptr;
+};
+
+class Sqrt2xGraph : public TestIOGraph, public Sqrt2xGraphlet
+{
+public:
+ Sqrt2xGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ Sqrt2xGraphlet::init(g(), shape);
+
+ _sqrt1->x(input());
+
+ _sqrt2->x(_sqrt1);
+
+ output()->from(_sqrt2);
+ }
+};
+
+} // namespace
+
+TEST(ChangeOutputsTest, change)
+{
+ Sqrt2xGraph g;
+
+ g.init({3, 3});
+
+ {
+ auto output = luci::output_node(g.g(), 0);
+ ASSERT_EQ(g.sqrt2(), output->from());
+ }
+
+ std::vector<std::string> names{"sqrt1"};
+
+ EXPECT_NO_THROW(luci::change_outputs(g.g(), names));
+
+ {
+ auto output = luci::output_node(g.g(), 0);
+ ASSERT_EQ(g.sqrt1(), output->from());
+ }
+}
+
+TEST(ChangeOutputsTest, name_not_found_NEG)
+{
+ Sqrt2xGraph g;
+
+ g.init({3, 3});
+
+ std::vector<std::string> names{"sqrt33"};
+
+ EXPECT_ANY_THROW(luci::change_outputs(g.g(), names));
+}
+
+TEST(ChangeOutputsTest, number_names_NEG)
+{
+ Sqrt2xGraph g;
+
+ g.init({3, 3});
+
+ std::vector<std::string> names{"sqrt1", "sqrt2"};
+
+ EXPECT_ANY_THROW(luci::change_outputs(g.g(), names));
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleNode *node)
+{
+#define CNVISIT_GRP(GRP) \
+ { \
+ CloneNodeLet<CN::GRP> cn(_graph); \
+ auto cloned = node->accept(&cn); \
+ if (cloned != nullptr) \
+ return cloned; \
+ }
+
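+  // Try each CloneNodeLet group in turn; a group's fallback visit() returns
+  // nullptr for nodes it does not handle, so the next group is consulted.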
+ CNVISIT_GRP(ABC);
+ CNVISIT_GRP(DEF);
+ CNVISIT_GRP(GHIJ);
+ CNVISIT_GRP(KLMN);
+ CNVISIT_GRP(OPQR);
+ CNVISIT_GRP(STUV);
+ CNVISIT_GRP(WXYZ);
+
+#undef CNVISIT_GRP
+
+ return nullptr;
+}
+
+} // namespace luci
namespace luci
{
-class CloneNode final : public luci::CircleNodeVisitor<luci::CircleNode *>
+// CloneNode-let type
+enum class CN
+{
+ ABC,
+ DEF,
+ GHIJ,
+ KLMN,
+ OPQR,
+ STUV,
+ WXYZ,
+};
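+// NOTE Operators are split across CloneNodeLet groups by the first letter of the
+// class name so that each specialization stays small.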
+
+template <CN ct> class CloneNodeLet;
+
+template <> class CloneNodeLet<CN::ABC> final : public luci::CircleNodeVisitor<luci::CircleNode *>
{
public:
- CloneNode(loco::Graph *graph) : _graph(graph){};
+ CloneNodeLet(loco::Graph *graph) : _graph(graph){};
public:
luci::CircleNode *visit(const luci::CircleAbs *) final;
luci::CircleNode *visit(const luci::CircleConv2D *) final;
luci::CircleNode *visit(const luci::CircleCos *) final;
luci::CircleNode *visit(const luci::CircleCustom *) final;
+
+ luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
+
+protected:
+ loco::Graph *_graph = nullptr;
+};
+
+template <> class CloneNodeLet<CN::DEF> final : public luci::CircleNodeVisitor<luci::CircleNode *>
+{
+public:
+ CloneNodeLet(loco::Graph *graph) : _graph(graph){};
+
+public:
luci::CircleNode *visit(const luci::CircleDepthToSpace *) final;
luci::CircleNode *visit(const luci::CircleDepthwiseConv2D *) final;
luci::CircleNode *visit(const luci::CircleDequantize *) final;
luci::CircleNode *visit(const luci::CircleFloorDiv *) final;
luci::CircleNode *visit(const luci::CircleFloorMod *) final;
luci::CircleNode *visit(const luci::CircleFullyConnected *) final;
+
+ luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
+
+protected:
+ loco::Graph *_graph = nullptr;
+};
+
+template <> class CloneNodeLet<CN::GHIJ> final : public luci::CircleNodeVisitor<luci::CircleNode *>
+{
+public:
+ CloneNodeLet(loco::Graph *graph) : _graph(graph){};
+
+public:
luci::CircleNode *visit(const luci::CircleGather *) final;
luci::CircleNode *visit(const luci::CircleGatherNd *) final;
luci::CircleNode *visit(const luci::CircleGreater *) final;
luci::CircleNode *visit(const luci::CircleGreaterEqual *) final;
- // luci::CircleNode *visit(const luci::CircleIf *) final;
+ luci::CircleNode *visit(const luci::CircleIf *) final;
+
+ luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
+
+protected:
+ loco::Graph *_graph = nullptr;
+};
+
+template <> class CloneNodeLet<CN::KLMN> final : public luci::CircleNodeVisitor<luci::CircleNode *>
+{
+public:
+ CloneNodeLet(loco::Graph *graph) : _graph(graph){};
+
+public:
luci::CircleNode *visit(const luci::CircleL2Normalize *) final;
luci::CircleNode *visit(const luci::CircleL2Pool2D *) final;
luci::CircleNode *visit(const luci::CircleLeakyRelu *) final;
luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4 *) final;
luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5 *) final;
luci::CircleNode *visit(const luci::CircleNotEqual *) final;
+
+ luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
+
+protected:
+ loco::Graph *_graph = nullptr;
+};
+
+template <> class CloneNodeLet<CN::OPQR> final : public luci::CircleNodeVisitor<luci::CircleNode *>
+{
+public:
+ CloneNodeLet(loco::Graph *graph) : _graph(graph){};
+
+public:
luci::CircleNode *visit(const luci::CircleOneHot *) final;
luci::CircleNode *visit(const luci::CirclePack *) final;
luci::CircleNode *visit(const luci::CirclePad *) final;
luci::CircleNode *visit(const luci::CirclePadV2 *) final;
luci::CircleNode *visit(const luci::CirclePow *) final;
luci::CircleNode *visit(const luci::CirclePRelu *) final;
+ luci::CircleNode *visit(const luci::CircleQuantize *) final;
luci::CircleNode *visit(const luci::CircleRange *) final;
luci::CircleNode *visit(const luci::CircleRank *) final;
luci::CircleNode *visit(const luci::CircleReduceAny *) final;
luci::CircleNode *visit(const luci::CircleReverseV2 *) final;
luci::CircleNode *visit(const luci::CircleRound *) final;
luci::CircleNode *visit(const luci::CircleRsqrt *) final;
+
+ luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
+
+protected:
+ loco::Graph *_graph = nullptr;
+};
+
+template <> class CloneNodeLet<CN::STUV> final : public luci::CircleNodeVisitor<luci::CircleNode *>
+{
+public:
+ CloneNodeLet(loco::Graph *graph) : _graph(graph){};
+
+public:
luci::CircleNode *visit(const luci::CircleScatterNd *) final;
luci::CircleNode *visit(const luci::CircleSegmentSum *) final;
luci::CircleNode *visit(const luci::CircleSelect *) final;
luci::CircleNode *visit(const luci::CircleUnidirectionalSequenceLSTM *) final;
luci::CircleNode *visit(const luci::CircleUnique *) final;
luci::CircleNode *visit(const luci::CircleUnpack *) final;
+
+ luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
+
+protected:
+ loco::Graph *_graph = nullptr;
+};
+
+template <> class CloneNodeLet<CN::WXYZ> final : public luci::CircleNodeVisitor<luci::CircleNode *>
+{
+public:
+ CloneNodeLet(loco::Graph *graph) : _graph(graph){};
+
+public:
luci::CircleNode *visit(const luci::CircleWhere *) final;
- // luci::CircleNode *visit(const luci::CircleWhile *) final;
+ luci::CircleNode *visit(const luci::CircleWhile *) final;
luci::CircleNode *visit(const luci::CircleZerosLike *) final;
+ luci::CircleNode *visit(const luci::CircleNode *) final { return nullptr; }
+
+protected:
+ loco::Graph *_graph = nullptr;
+};
+
+class CloneNode final : public luci::CircleNodeVisitor<luci::CircleNode *>
+{
+public:
+ CloneNode(loco::Graph *graph) : _graph(graph){};
+
+public:
// Circle Only
luci::CircleNode *visit(const luci::CircleBCQFullyConnected *) final;
luci::CircleNode *visit(const luci::CircleBCQGather *) final;
luci::CircleNode *visit(const luci::CircleInstanceNorm *) final;
+  // NOTE CircleInput and CircleOutput are not handled here as they need to be
+  //      linked with graph I/O
+
// Virtual
luci::CircleNode *visit(const luci::CircleCustomOut *) final;
- // luci::CircleNode *visit(const luci::CircleIfOut *) final;
+ luci::CircleNode *visit(const luci::CircleIfOut *) final;
// luci::CircleNode *visit(const luci::CircleInput *) final;
luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4Out *) final;
luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5Out *) final;
luci::CircleNode *visit(const luci::CircleTopKV2Out *) final;
luci::CircleNode *visit(const luci::CircleUniqueOut *) final;
luci::CircleNode *visit(const luci::CircleUnpackOut *) final;
- // luci::CircleNode *visit(const luci::CircleWhileOut *) final;
+ luci::CircleNode *visit(const luci::CircleWhileOut *) final;
+
+  // Remaining ops are handled in the CircleNode visit below, which dispatches to the CloneNodeLet groups
+ luci::CircleNode *visit(const luci::CircleNode *) final;
// NOTE CircleNodeVisitor will throw if not supported here
{
auto ifm_shape = luci::shape_get(node->value()).template as<loco::TensorShape>();
assert(ifm_shape.rank() == 4);
+ assert(ifm_shape.dim(1).known());
+ assert(ifm_shape.dim(2).known());
uint32_t input_height = ifm_shape.dim(1).value();
uint32_t input_width = ifm_shape.dim(2).value();
if (node->padding() == luci::Padding::VALID)
{
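+    // input/window sizes are unsigned, so the subtractions below would wrap
+    // around unless input + stride is strictly greater than the window extent.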
+ LUCI_ASSERT(input_height + stride_height > effective_window_height, "Invalid shape");
+ LUCI_ASSERT(input_width + stride_width > effective_window_width, "Invalid shape");
output_height = (input_height + stride_height - effective_window_height) / stride_height;
output_width = (input_width + stride_width - effective_window_width) / stride_width;
}
auto ker_shape = luci::shape_get(node->filter()).template as<loco::TensorShape>();
assert(ifm_shape.rank() == 4);
assert(ker_shape.rank() == 4);
+ assert(ifm_shape.dim(1).known());
+ assert(ifm_shape.dim(2).known());
+ assert(ker_shape.dim(1).known());
+ assert(ker_shape.dim(2).known());
uint32_t input_height = ifm_shape.dim(1).value();
uint32_t input_width = ifm_shape.dim(2).value();
if (node->padding() == luci::Padding::VALID)
{
+ LUCI_ASSERT(input_height + stride_height > effective_ker_height, "Invalid shape");
+ LUCI_ASSERT(input_width + stride_width > effective_ker_width, "Invalid shape");
output_height = (input_height + stride_height - effective_ker_height) / stride_height;
output_width = (input_width + stride_width - effective_ker_width) / stride_width;
}
auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
- INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank()
- << ")" << std::endl;
-
assert(ifm_shape.rank() == 4);
assert(ker_shape.rank() == 4);
assert(ifm_shape.dim(3) == ker_shape.dim(3));
ofm_shape.dim(2) = os.width;
ofm_shape.dim(3) = ker_shape.dim(0);
+ INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank()
+ << ") output(" << ofm_shape.dim(0).value() << "," << ofm_shape.dim(1).value() << ","
+ << ofm_shape.dim(2).value() << "," << ofm_shape.dim(3).value() << ") " << node->name()
+ << std::endl;
+
return loco::NodeShape{ofm_shape};
}
loco::NodeShape visit(const luci::CirclePRelu *node) final { return infer_p_relu(node); }
+ loco::NodeShape visit(const luci::CircleQuantize *node) final
+ {
+ const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ return loco::NodeShape{input_shape};
+ }
+
loco::NodeShape visit(const luci::CircleRange *node) final { return infer_range(node); }
loco::NodeShape visit(const luci::CircleRank *) final
return input_type;
}
+ // TODO support S16
+ loco::DataType visit(const luci::CircleQuantize *) final { return loco::DataType::U8; }
+
loco::DataType visit(const luci::CircleRange *node) final
{
return luci::dtype_get(node->start());
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleAbs *)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleAbs *)
{
return _graph->nodes()->create<luci::CircleAbs>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleAdd *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleAdd *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleAddN *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleAddN *node)
{
auto arity = node->arity();
return _graph->nodes()->create<luci::CircleAddN>(arity);
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleArgMax *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleArgMax *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleArgMax>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleArgMin *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleArgMin *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleArgMin>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleAveragePool2D *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleAveragePool2D *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleBatchMatMul *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleBatchMatMul *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleBatchMatMul>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleBatchToSpaceND *)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleBatchToSpaceND *)
{
return _graph->nodes()->create<luci::CircleBatchToSpaceND>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleCast *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleCast *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleCast>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleCeil *)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleCeil *)
{
return _graph->nodes()->create<luci::CircleCeil>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleConcatenation *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleConcatenation *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleConst *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleConst *node)
{
return clone_circleconst(node, _graph);
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleConv2D *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleConv2D *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleCos *)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleCos *)
{
return _graph->nodes()->create<luci::CircleCos>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleCustom *node)
+luci::CircleNode *CloneNodeLet<CN::ABC>::visit(const luci::CircleCustom *node)
{
uint32_t num_in = node->numInputs();
uint32_t num_out = node->numOutputs();
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleDepthToSpace *node)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleDepthToSpace *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleDepthToSpace>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleDepthwiseConv2D *node)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleDepthwiseConv2D *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleDequantize *)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleDequantize *)
{
return _graph->nodes()->create<luci::CircleDequantize>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleDiv *node)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleDiv *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleElu *)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleElu *)
{
return _graph->nodes()->create<luci::CircleElu>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleEqual *)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleEqual *)
{
return _graph->nodes()->create<luci::CircleEqual>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleExp *)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleExp *)
{
return _graph->nodes()->create<luci::CircleExp>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleExpandDims *)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleExpandDims *)
{
return _graph->nodes()->create<luci::CircleExpandDims>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleFakeQuant *node)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleFakeQuant *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleFakeQuant>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleFill *)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleFill *)
{
return _graph->nodes()->create<luci::CircleFill>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleFloor *)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleFloor *)
{
return _graph->nodes()->create<luci::CircleFloor>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleFloorDiv *)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleFloorDiv *)
{
return _graph->nodes()->create<luci::CircleFloorDiv>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleFloorMod *)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleFloorMod *)
{
return _graph->nodes()->create<luci::CircleFloorMod>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleFullyConnected *node)
+luci::CircleNode *CloneNodeLet<CN::DEF>::visit(const luci::CircleFullyConnected *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleGather *node)
+luci::CircleNode *CloneNodeLet<CN::GHIJ>::visit(const luci::CircleGather *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleGather>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleGatherNd *)
+luci::CircleNode *CloneNodeLet<CN::GHIJ>::visit(const luci::CircleGatherNd *)
{
return _graph->nodes()->create<luci::CircleGatherNd>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleGreater *)
+luci::CircleNode *CloneNodeLet<CN::GHIJ>::visit(const luci::CircleGreater *)
{
return _graph->nodes()->create<luci::CircleGreater>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleGreaterEqual *)
+luci::CircleNode *CloneNodeLet<CN::GHIJ>::visit(const luci::CircleGreaterEqual *)
{
return _graph->nodes()->create<luci::CircleGreaterEqual>();
}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNodeLet<CN::GHIJ>::visit(const luci::CircleIf *node)
+{
+ auto ic = node->input_count();
+ auto oc = node->output_count();
+
+ return _graph->nodes()->create<luci::CircleIf>(ic, oc);
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_If)
+{
+ auto g = loco::make_graph();
+ auto node_if = g->nodes()->create<luci::CircleIf>(1, 1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_if, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_if = dynamic_cast<luci::CircleIf *>(cloned);
+ ASSERT_NE(nullptr, cloned_if);
+ ASSERT_EQ(-1, cloned_if->then_branch());
+ ASSERT_EQ(-1, cloned_if->else_branch());
+ ASSERT_EQ(nullptr, cloned_if->then_graph());
+ ASSERT_EQ(nullptr, cloned_if->else_graph());
+}
#include <luci/Service/CircleShapeInference.h>
#include <luci/Service/CircleTypeInference.h>
+#include "CircleCloneNode.h"
+
namespace
{
}
} // namespace luci
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleIfOut *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleIfOut>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_IfOut)
+{
+ auto g = loco::make_graph();
+ auto node_iout = g->nodes()->create<luci::CircleIfOut>();
+ node_iout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_iout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_iout = dynamic_cast<luci::CircleIfOut *>(cloned);
+ ASSERT_NE(nullptr, cloned_iout);
+ ASSERT_EQ(node_iout->index(), cloned_iout->index());
+}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleL2Normalize *node)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleL2Normalize *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleL2Pool2D *node)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleL2Pool2D *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLeakyRelu *node)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLeakyRelu *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleLeakyRelu>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLess *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLess *)
{
return _graph->nodes()->create<luci::CircleLess>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLessEqual *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLessEqual *)
{
return _graph->nodes()->create<luci::CircleLessEqual>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLocalResponseNormalization *node)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLocalResponseNormalization *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleLocalResponseNormalization>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLog *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLog *)
{
return _graph->nodes()->create<luci::CircleLog>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLogSoftmax *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLogSoftmax *)
{
return _graph->nodes()->create<luci::CircleLogSoftmax>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLogicalAnd *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLogicalAnd *)
{
return _graph->nodes()->create<luci::CircleLogicalAnd>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLogicalNot *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLogicalNot *)
{
return _graph->nodes()->create<luci::CircleLogicalNot>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLogicalOr *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLogicalOr *)
{
return _graph->nodes()->create<luci::CircleLogicalOr>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleLogistic *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleLogistic *)
{
return _graph->nodes()->create<luci::CircleLogistic>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleMatrixDiag *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleMatrixDiag *)
{
return _graph->nodes()->create<luci::CircleMatrixDiag>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleMatrixSetDiag *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleMatrixSetDiag *)
{
return _graph->nodes()->create<luci::CircleMatrixSetDiag>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleMaxPool2D *node)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleMaxPool2D *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleMaximum *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleMaximum *)
{
return _graph->nodes()->create<luci::CircleMaximum>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleMean *node)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleMean *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleMean>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleMinimum *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleMinimum *)
{
return _graph->nodes()->create<luci::CircleMinimum>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleMirrorPad *node)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleMirrorPad *node)
{
if (node->mode() == luci::MirrorPadMode::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleMul *node)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleMul *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleNeg *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleNeg *)
{
return _graph->nodes()->create<luci::CircleNeg>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV4 *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleNonMaxSuppressionV4 *)
{
return _graph->nodes()->create<luci::CircleNonMaxSuppressionV4>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV5 *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleNonMaxSuppressionV5 *)
{
return _graph->nodes()->create<luci::CircleNonMaxSuppressionV5>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleNotEqual *)
+luci::CircleNode *CloneNodeLet<CN::KLMN>::visit(const luci::CircleNotEqual *)
{
return _graph->nodes()->create<luci::CircleNotEqual>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleOneHot *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleOneHot *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleOneHot>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CirclePRelu *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CirclePRelu *)
{
return _graph->nodes()->create<luci::CirclePRelu>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CirclePack *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CirclePack *node)
{
auto *cloned = _graph->nodes()->create<luci::CirclePack>(node->values_count());
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CirclePad *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CirclePad *)
{
return _graph->nodes()->create<luci::CirclePad>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CirclePadV2 *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CirclePadV2 *)
{
return _graph->nodes()->create<luci::CirclePadV2>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CirclePow *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CirclePow *)
{
return _graph->nodes()->create<luci::CirclePow>();
}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleQuantize *)
+{
+ return _graph->nodes()->create<luci::CircleQuantize>();
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Quantize)
+{
+ auto g = loco::make_graph();
+ auto node_q = g->nodes()->create<luci::CircleQuantize>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_q, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_q = dynamic_cast<luci::CircleQuantize *>(cloned);
+ ASSERT_NE(nullptr, cloned_q);
+}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleRange *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleRange *)
{
return _graph->nodes()->create<luci::CircleRange>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleRank *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleRank *)
{
return _graph->nodes()->create<luci::CircleRank>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleReduceAny *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReduceAny *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleReduceAny>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleReduceMax *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReduceMax *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleReduceMax>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleReduceMin *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReduceMin *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleReduceMin>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleReduceProd *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReduceProd *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleReduceProd>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleRelu *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleRelu *)
{
return _graph->nodes()->create<luci::CircleRelu>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleRelu6 *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleRelu6 *)
{
return _graph->nodes()->create<luci::CircleRelu6>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleReluN1To1 *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReluN1To1 *)
{
return _graph->nodes()->create<luci::CircleReluN1To1>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleReshape *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReshape *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleReshape>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleResizeBilinear *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleResizeBilinear *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleResizeBilinear>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleResizeNearestNeighbor *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleResizeNearestNeighbor *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleResizeNearestNeighbor>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleReverseSequence *node)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReverseSequence *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleReverseSequence>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleReverseV2 *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReverseV2 *)
{
return _graph->nodes()->create<luci::CircleReverseV2>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleRound *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleRound *)
{
return _graph->nodes()->create<luci::CircleRound>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleRsqrt *)
+luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleRsqrt *)
{
return _graph->nodes()->create<luci::CircleRsqrt>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleScatterNd *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleScatterNd *)
{
return _graph->nodes()->create<luci::CircleScatterNd>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSegmentSum *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSegmentSum *)
{
return _graph->nodes()->create<luci::CircleSegmentSum>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSelect *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSelect *)
{
return _graph->nodes()->create<luci::CircleSelect>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSelectV2 *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSelectV2 *)
{
return _graph->nodes()->create<luci::CircleSelectV2>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleShape *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleShape *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleShape>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSin *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSin *)
{
return _graph->nodes()->create<luci::CircleSin>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSlice *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSlice *)
{
return _graph->nodes()->create<luci::CircleSlice>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSoftmax *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSoftmax *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleSoftmax>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSpaceToBatchND *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSpaceToBatchND *)
{
return _graph->nodes()->create<luci::CircleSpaceToBatchND>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSpaceToDepth *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSpaceToDepth *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleSpaceToDepth>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSparseToDense *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSparseToDense *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleSparseToDense>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSplit *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSplit *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleSplit>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSplitV *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSplitV *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleSplitV>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSqrt *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSqrt *)
{
return _graph->nodes()->create<luci::CircleSqrt>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSquare *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSquare *)
{
return _graph->nodes()->create<luci::CircleSquare>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSquaredDifference *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSquaredDifference *)
{
return _graph->nodes()->create<luci::CircleSquaredDifference>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSqueeze *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSqueeze *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleSqueeze>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleStridedSlice *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleStridedSlice *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleStridedSlice>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSub *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSub *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleSum *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleSum *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleSum>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleTanh *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleTanh *)
{
return _graph->nodes()->create<luci::CircleTanh>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleTile *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleTile *)
{
return _graph->nodes()->create<luci::CircleTile>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleTopKV2 *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleTopKV2 *)
{
return _graph->nodes()->create<luci::CircleTopKV2>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleTranspose *)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleTranspose *)
{
return _graph->nodes()->create<luci::CircleTranspose>();
}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleTransposeConv *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleTransposeConv *node)
{
if (node->padding() == luci::Padding::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleUnidirectionalSequenceLSTM *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleUnidirectionalSequenceLSTM *node)
{
if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
return nullptr;
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleUnique *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleUnique *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleUnique>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleUnpack *node)
+luci::CircleNode *CloneNodeLet<CN::STUV>::visit(const luci::CircleUnpack *node)
{
auto *cloned = _graph->nodes()->create<luci::CircleUnpack>();
if (cloned != nullptr)
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleWhere *)
+luci::CircleNode *CloneNodeLet<CN::WXYZ>::visit(const luci::CircleWhere *)
{
return _graph->nodes()->create<luci::CircleWhere>();
}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNodeLet<CN::WXYZ>::visit(const luci::CircleWhile *node)
+{
+ auto ic = node->input_count();
+ auto oc = node->output_count();
+
+ return _graph->nodes()->create<luci::CircleWhile>(ic, oc);
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_While)
+{
+ auto g = loco::make_graph();
+ auto node_while = g->nodes()->create<luci::CircleWhile>(1, 1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_while, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_while = dynamic_cast<luci::CircleWhile *>(cloned);
+ ASSERT_NE(nullptr, cloned_while);
+ ASSERT_EQ(-1, cloned_while->cond_branch());
+ ASSERT_EQ(-1, cloned_while->body_branch());
+ ASSERT_EQ(nullptr, cloned_while->cond_graph());
+ ASSERT_EQ(nullptr, cloned_while->body_graph());
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleWhileOut *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleWhileOut>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_WhileOut)
+{
+ auto g = loco::make_graph();
+ auto node_iout = g->nodes()->create<luci::CircleWhileOut>();
+ node_iout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_iout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_iout = dynamic_cast<luci::CircleWhileOut *>(cloned);
+ ASSERT_NE(nullptr, cloned_iout);
+ ASSERT_EQ(node_iout->index(), cloned_iout->index());
+}
namespace luci
{
-luci::CircleNode *CloneNode::visit(const luci::CircleZerosLike *)
+luci::CircleNode *CloneNodeLet<CN::WXYZ>::visit(const luci::CircleZerosLike *)
{
return _graph->nodes()->create<luci::CircleZerosLike>();
}
#include "luci/Service/Validate.h"
#include <luci/IR/Nodes/CircleOutput.h>
+#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Log.h>
+#include <luci/LogHelper.h>
#include <loco/IR/NodeShape.h>
#include <cassert>
#include <unordered_map>
#include <vector>
+#include <iostream>
namespace
{
assert(circle_output->from() != nullptr);
auto circle_node = loco::must_cast<luci::CircleNode *>(circle_output->from());
- // Shape and dtype validation for CiecleOutputExclude is not needed
+ // Shape and dtype validation for CircleOutputExclude is not needed
if (dynamic_cast<luci::CircleOutputExclude *>(circle_node))
continue;
return true;
}
+class VirtualNodeDetector final : public luci::CircleNodeVisitor<bool>
+{
+public:
+ VirtualNodeDetector() {}
+
+public:
+ bool visit(const luci::CircleBidirectionalSequenceLSTMOut *) final { return true; }
+ bool visit(const luci::CircleCustomOut *) final { return true; }
+ bool visit(const luci::CircleIfOut *) final { return true; }
+ bool visit(const luci::CircleNonMaxSuppressionV4Out *) final { return true; }
+ bool visit(const luci::CircleNonMaxSuppressionV5Out *) final { return true; }
+ bool visit(const luci::CircleSplitOut *) final { return true; }
+ bool visit(const luci::CircleSplitVOut *) final { return true; }
+ bool visit(const luci::CircleTopKV2Out *) final { return true; }
+ bool visit(const luci::CircleUnpackOut *) final { return true; }
+ bool visit(const luci::CircleUniqueOut *) final { return true; }
+ bool visit(const luci::CircleWhileOut *) final { return true; }
+ bool visit(const luci::CircleOutputDummy *) final { return true; }
+ bool visit(const luci::CircleOutputExclude *) final { return true; }
+
+ // Return false by default
+ bool visit(const luci::CircleNode *) final { return false; }
+};
+
} // namespace
namespace luci
for (uint32_t n = 0; n < nodes->size(); ++n)
{
auto node = loco::must_cast<luci::CircleNode *>(nodes->at(n));
+ // skip virtual nodes
+ VirtualNodeDetector d;
+ if (node->accept(&d))
+ continue;
+
auto name = node->name();
if (name.empty())
return false;
bool validate_unique_name(luci::Module *m)
{
+ LOGGER(l);
+
std::unordered_map<std::string, bool> names_col;
for (size_t g = 0; g < m->size(); ++g)
auto output = dynamic_cast<luci::CircleOutput *>(node);
if (output != nullptr)
continue;
+ // skip virtual nodes
+ VirtualNodeDetector d;
+ if (node->accept(&d))
+ continue;
auto name = node->name();
+ INFO(l) << "Node: " << name << ", " << (uint32_t)(node->opcode()) << std::endl;
auto it = names_col.find(name);
if (it != names_col.end())
+ {
+ INFO(l) << "validate_unique_name: found duplicate " << name << ", " << graph->name()
+ << std::endl;
return false;
+ }
names_col[name] = true;
}
+ // The same tensor name may exist in different subgraphs.
+ names_col.clear();
+ }
+
+ return true;
+}
+
+bool validate(luci::Module *module)
+{
+ LOGGER(l);
+
+ INFO(l) << "--- validate Module -----------------------------------";
+
+ for (size_t g = 0; g < module->size(); ++g)
+ {
+ auto graph = module->graph(g);
+
+ INFO(l) << luci::fmt(graph) << std::endl;
+
+ if (!validate(graph))
+ {
+ std::cerr << "ERROR: Invalid circle model" << std::endl;
+ return false;
+ }
+ if (!validate_name(graph))
+ {
+ std::cerr << "ERROR: circle model has empty name" << std::endl;
+ return false;
+ }
+ }
+
+ if (!validate_unique_name(module))
+ {
+ std::cerr << "ERROR: circle model has duplicate names" << std::endl;
+ return false;
}
return true;
public:
luci::CircleOutput *output(int idx) { return _outputs[idx]; }
+ uint32_t num_outputs(void) { return N; }
protected:
std::array<loco::GraphOutput *, N> _graph_outputs;
addread(Add_000)
addread(Add_001)
addread(Add_U8_000)
+addread(Add_STR_000)
+addread(Add_STR_001)
addread(AddN_000)
addread(ArgMax_000)
addread(ArgMax_001)
addread(PadV2_000)
addread(Pow_000)
addread(PRelu_000)
+addread(Quantize_000)
addread(Range_000)
addread(Rank_000)
addread(ReduceAny_000)
addwrite(Add_000)
addwrite(Add_001)
addwrite(Add_U8_000)
+addwrite(Add_STR_000)
+addwrite(Add_STR_001)
addwrite(AddN_000)
addwrite(ArgMax_000)
addwrite(ArgMax_001)
addwrite(PadV2_000)
addwrite(Pow_000)
addwrite(PRelu_000)
+addwrite(Quantize_000)
addwrite(Range_000)
addwrite(Rank_000)
addwrite(ReduceAny_000)
one-optimize
one-quantize
one-pack
+ one-profile
one-codegen
one-prepare-venv
+ onecc
)
foreach(ONE_COMMAND IN ITEMS ${ONE_COMMAND_FILES})
set(ONE_UTILITY_FILES
one-build.template.cfg
+ onecc.template.cfg
utils.py
+ conv_mixin_1.8.0.patch
)
foreach(ONE_UTILITY IN ITEMS ${ONE_UTILITY_FILES})
--- /dev/null
+--- a/onnx_tf/handlers/backend/conv_mixin.py
++++ b/onnx_tf/handlers/backend/conv_mixin.py
+@@ -98,7 +98,7 @@
+ depthwise = (x_rank == 4 and len(weight_shape) == 4 and group != 1 and
+ not transpose and not (None in weight_shape))
+ if depthwise and isinstance(x_shape, np.ndarray):
+- depthwise = group == x_shape[1]
++ depthwise = bool(group == x_shape[1])
+
+ if depthwise is True:
+ # Depthwise convolution.
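The bool() cast matters because comparing a Python int with a numpy array element yields
numpy.bool_ rather than the builtin bool, so the later `if depthwise is True:` identity
check never fires. A minimal sketch of the behavior, assuming plain numpy semantics
(the shape values are arbitrary):

import numpy as np

x_shape = np.array([1, 8, 16, 16])  # example NCHW-like shape
group = 8

flag = group == x_shape[1]
print(type(flag))           # <class 'numpy.bool_'>
print(flag is True)         # False: numpy.bool_ is not the builtin True object
print(bool(flag) is True)   # True: the explicit cast restores the identity check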
# dummy driver for interface test
set(DUMMY_DRIVER_SRC src/dummy-compile.cpp)
set(HELP_DRIVER_SRC src/help-compile.cpp)
+set(DUMMY_PROFILE_SRC src/dummy-profile.cpp)
+set(HELP_PROFILE_SRC src/help-profile.cpp)
add_executable(dummy-compile ${DUMMY_DRIVER_SRC})
add_executable(help-compile ${HELP_DRIVER_SRC})
+add_executable(dummy-profile ${DUMMY_PROFILE_SRC})
+add_executable(help-profile ${HELP_PROFILE_SRC})
set(DUMMY_DRIVER "${CMAKE_CURRENT_BINARY_DIR}/dummy-compile")
set(HELP_DRIVER "${CMAKE_CURRENT_BINARY_DIR}/help-compile")
+set(DUMMY_PROFILE "${CMAKE_CURRENT_BINARY_DIR}/dummy-profile")
+set(HELP_PROFILE "${CMAKE_CURRENT_BINARY_DIR}/help-profile")
install(FILES ${DUMMY_DRIVER}
PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
GROUP_READ GROUP_EXECUTE
WORLD_READ WORLD_EXECUTE
DESTINATION test)
+
+install(FILES ${DUMMY_PROFILE}
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION test)
+
+install(FILES ${HELP_PROFILE}
+ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE
+ GROUP_READ GROUP_EXECUTE
+ WORLD_READ WORLD_EXECUTE
+ DESTINATION test)
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * dummy-profile only tests its interface rather than its functionality.
+ *
+ * ./dummy-profile ${INPUT_NAME}
+ * dummy-profile dummy output!!!
+ */
+
+#include <iostream>
+#include <fstream>
+#include <string>
+
+int main(int argc, char **argv)
+{
+ if (argc != 2)
+ return EXIT_FAILURE;
+
+ std::cout << "dummy-profile dummy output!!!" << std::endl;
+
+ return EXIT_SUCCESS;
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * help-profile prints dummy help message.
+ *
+ * $ ./help-profile -h
+ * HELP MESSAGE!!
+ */
+
+#include <iostream>
+#include <fstream>
+#include <string>
+
+int main(int argc, char **argv)
+{
+ if (argc != 2)
+ return EXIT_FAILURE;
+
+ std::string opt_h{"-h"};
+ std::string argv_1{argv[1]};
+
+ if (opt_h != argv_1)
+ return EXIT_FAILURE;
+
+ std::cout << "HELP MESSAGE!!" << std::endl;
+
+ return EXIT_SUCCESS;
+}
This will convert TensorFlow model (.pb) file to our circle model. You can also
directly call this command. one-import-tf invokes tf2tfliteV2.py script that
will internally use TensorFlow lite converter and then invoke tflite2circle
-converter to convert tflite model to circle model.
+converter to convert tflite model to circle model.
As tf2tfliteV2.py runs TensorFlow lite converter, you need to have TensorFlow
installed in your system. We recommend using 2.3.0 for now.
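For example, the same import step can be scripted. The following is a minimal sketch,
assuming the one-import-tf command line flags mirror the [one-import-tf] configuration
keys used in the template configuration (the flag spelling and file names are assumptions):

import subprocess

# assumed flags; each corresponds to an [one-import-tf] configuration key
cmd = [
    "one-import-tf",
    "--input_path", "inception_v3.pb",                       # frozen TensorFlow graph
    "--output_path", "inception_v3.circle",                  # resulting circle model
    "--input_arrays", "input",
    "--input_shapes", "1,299,299,3",
    "--output_arrays", "InceptionV3/Predictions/Reshape_1",
    "--converter_version", "v1",
]
subprocess.run(cmd, check=True)  # raises CalledProcessError on a non-zero exit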
execution.
- fuse_preactivation_batchnorm: This fuses batch normalization operators of pre-activations to Conv operators.
- fuse_activation_function: This fuses Activation function to a preceding operator.
+- fuse_mean_with_mean: This fuses two consecutive ReduceMean operations into one.
+- fuse_transpose_with_mean: This fuses ReduceMean with a preceding Transpose under certain conditions.
- make_batchnorm_gamma_positive: This makes negative gamma of batch normalization into a small positive value (1e-10).
Note that this pass can change the execution result of the model.
So, use it only when the impact is known to be acceptable.
- mute_warnings : This will turn off warning messages.
- generate_profile_data : This will turn on profiling data generation.
+- remove_fakequant : This will remove all fakequant operators.
+- remove_quantdequant : This will remove all Quantize-Dequantize sequences.
- remove_redundant_reshape : This fuses or removes redundant reshape operators.
- remove_redundant_transpose : This fuses or removes redundant transpose operators.
- remove_unnecessary_reshape : This removes unnecessary reshape operators.
normal BatchMatMul operator
- resolve_customop_matmul: This will convert Custom(MatMul) to normal MatMul
operator
+- resolve_customop_max_pool_with_argmax: This will convert Custom(MaxPoolWithArgmax)
+ to a network of builtin operators.
- shuffle_weight_to_16x1float32 : This will convert weight format of FullyConnected to SHUFFLED16x1FLOAT32.
Note that it only converts weights whose row is a multiple of 16.
- substitute_pack_to_reshape : This will convert single input Pack to Reshape.
- substitute_squeeze_to_reshape : This will convert Squeeze to Reshape under certain conditions.
+- substitute_strided_slice_to_reshape : This will convert StridedSlice to Reshape under certain conditions.
- substitute_transpose_to_reshape : This will convert Transpose to Reshape under certain conditions.
- transform_min_max_to_relu6: This will transform Minimum-Maximum pattern to Relu6 operator.
+- transform_min_relu_to_relu6: This will transform Minimum(6)-Relu pattern to Relu6 operator.
There are convenience options that enable multiple optimizations at once.
- O1: fuse_bcq, fuse_instnorm, resolve_customop_add, resolve_customop_batchmatmul,
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
import utils as _utils
+# TODO Find better way to suppress traceback on error
+# This suppression is applied only to `one-build`
+sys.tracebacklimit = 0
+
def _get_parser():
parser = argparse.ArgumentParser(
def _parse_cfg(args):
config = configparser.ConfigParser()
+ config.optionxform = str
parsed = config.read(os.path.expanduser(getattr(args, 'config')))
if not parsed:
raise FileNotFoundError('Not found given configuration file')
for section in section_to_run:
driver_path = os.path.join(dir_path, _get_driver_name(section))
cmd = [driver_path, '--config', getattr(args, 'config'), '--section', section]
- with subprocess.Popen(
- cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(cmd)
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
input_shapes=1,299,299,3
output_arrays=InceptionV3/Predictions/Reshape_1
converter_version=v1
+model_format=graph_def
[one-optimize]
input_path=inception_v3.circle
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
import argparse
import copy
+import glob
import itertools
+import ntpath
import os
import subprocess
import sys
import utils as _utils
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
def _get_backends_list():
+ """
+ [one hierarchy]
+ one
+ ├── backends
+ ├── bin
+ ├── doc
+ ├── include
+ ├── lib
+ └── test
+
+ The list where `one-codegen` finds its backends
+ - `bin` folder where `one-codegen` exists
+ - `backends` folder
+
+ NOTE If there are backends of the same name in different places,
+ the closer to the top in the list, the higher the priority.
+ """
dir_path = os.path.dirname(os.path.realpath(__file__))
- files = [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]
+ backend_set = set()
+
+ # bin folder
+ files = [f for f in glob.glob(dir_path + '/*-compile')]
+ # backends folder
+ files += [
+ f for f in glob.glob(dir_path + '/../backends/**/*-compile', recursive=True)
+ ]
+ # TODO find backends in `$PATH`
+
backends_list = []
for cand in files:
- if cand.endswith('-compile'):
- # 8 : length of '-compile'
- backends_list.append(cand[:-8])
+ base = ntpath.basename(cand)
+ if not base in backend_set and os.path.isfile(cand) and os.access(cand, os.X_OK):
+ backend_set.add(base)
+ backends_list.append(cand)
+
return backends_list
-def _get_parser():
+def _get_parser(backends_list):
codegen_usage = 'one-codegen [-h] [-v] [-C CONFIG] [-b BACKEND] [--] [COMMANDS FOR BACKEND]'
parser = argparse.ArgumentParser(
description='command line tool for code generation', usage=codegen_usage)
_utils._add_default_arg(parser)
# get backend list in the directory
- backends_list = _get_backends_list()
- if not backends_list:
- backends_list_message = '(There is no available backend drivers)'
+ backends_name = [ntpath.basename(f) for f in backends_list]
+ if not backends_name:
+ backends_name_message = '(There is no available backend drivers)'
else:
- backends_list_message = '(available backend drivers: ' + '.'.join(
- backends_list) + ')'
- backend_help_message = 'backend name to use ' + backends_list_message
+ backends_name_message = '(available backend drivers: ' + ', '.join(
+ backends_name) + ')'
+ backend_help_message = 'backend name to use ' + backends_name_message
parser.add_argument('-b', '--backend', type=str, help=backend_help_message)
return parser
def main():
+ # get backend list
+ backends_list = _get_backends_list()
+
# parse arguments
- parser = _get_parser()
+ parser = _get_parser(backends_list)
args, backend_args, unknown_args = _parse_arg(parser)
# parse configuration file
_verify_arg(parser, args)
# make a command to run given backend driver
- dir_path = os.path.dirname(os.path.realpath(__file__))
- codegen_path = os.path.join(dir_path, getattr(args, 'backend') + '-compile')
+ codegen_path = None
+ backend_base = getattr(args, 'backend') + '-compile'
+ for cand in backends_list:
+ if ntpath.basename(cand) == backend_base:
+ codegen_path = cand
+ if not codegen_path:
+ raise FileNotFoundError(backend_base + ' not found')
codegen_cmd = [codegen_path] + backend_args + unknown_args
if _utils._is_valid_attr(args, 'command'):
codegen_cmd += getattr(args, 'command').split()
# run backend driver
- with subprocess.Popen(
- codegen_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(codegen_cmd, err_prefix=backend_base)
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
# driver
parser.add_argument(
- 'driver', type=str, help='driver name to run (supported: tf, tflite, bcq)')
+ 'driver', type=str, help='driver name to run (supported: tf, tflite,' \
+ ' bcq, onnx)')
# version
dir_path = os.path.dirname(os.path.realpath(__file__))
return {
'bcq': 'one-import-bcq',
'tf': 'one-import-tf',
- 'tflite': 'one-import-tflite'
+ 'tflite': 'one-import-tflite',
+ 'onnx': 'one-import-onnx',
}[driver_name]
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
import utils as _utils
import generate_bcq_output_arrays as _bcq_info_gen
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
def _get_parser():
parser = argparse.ArgumentParser(
f.write((' '.join(generate_bcq_metadata_cmd) + '\n').encode())
# generate BCQ information metadata
- with subprocess.Popen(
- generate_bcq_metadata_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(generate_bcq_metadata_cmd, logfile=f)
# get output_arrays with BCQ
bcq_output_arrays = _bcq_info_gen.get_bcq_output_arrays(
f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
# convert tf to tflite
- with subprocess.Popen(
- tf2tfliteV2_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(tf2tfliteV2_cmd, logfile=f)
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
f.write((' '.join(tflite2circle_cmd) + '\n').encode())
# convert tflite to circle
- with subprocess.Popen(
- tflite2circle_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(tflite2circle_cmd, logfile=f)
def main():
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
import utils as _utils
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
def _get_parser():
parser = argparse.ArgumentParser(
tf2tfliteV2_group.add_argument('--model_format', default='saved_model')
tf2tfliteV2_group.add_argument('--converter_version', default='v2')
+ # save intermediate file(s)
+ parser.add_argument(
+ '--save_intermediate',
+ action='store_true',
+ help='Save intermediate files to output folder')
+
return parser
return args
+def _apply_verbosity(verbosity):
+ # NOTE
+ # TF_CPP_MIN_LOG_LEVEL
+ # 0 : INFO + WARNING + ERROR + FATAL
+ # 1 : WARNING + ERROR + FATAL
+ # 2 : ERROR + FATAL
+ # 3 : FATAL
+ if verbosity:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
+ else:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+
+
def _convert(args):
+ _apply_verbosity(args.verbose)
+
# get file path to log
dir_path = os.path.dirname(os.path.realpath(__file__))
logfile_path = os.path.realpath(args.output_path) + '.log'
with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
+ # save intermediate
+ if _utils._is_valid_attr(args, 'save_intermediate'):
+ tmpdir = os.path.dirname(logfile_path)
# convert onnx to tf saved model
onnx_model = onnx.load(getattr(args, 'input_path'))
tf_savedmodel = onnx_tf.backend.prepare(onnx_model)
f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
# convert tf to tflite
- with subprocess.Popen(
- tf2tfliteV2_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(tf2tfliteV2_cmd, logfile=f)
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
f.write((' '.join(tflite2circle_cmd) + '\n').encode())
# convert tflite to circle
- with subprocess.Popen(
- tflite2circle_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
def main():
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
type=str,
help='names of the output arrays, comma-separated')
+ # save intermediate file(s)
+ parser.add_argument(
+ '--save_intermediate',
+ action='store_true',
+ help='Save intermediate files to output folder')
+
return parser
logfile_path = os.path.realpath(args.output_path) + '.log'
with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
+ # save intermediate
+ if _utils._is_valid_attr(args, 'save_intermediate'):
+ tmpdir = os.path.dirname(logfile_path)
# make a command to convert from tf to tflite
tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
tf2tfliteV2_output_path = os.path.join(
f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
# convert tf to tflite
- with subprocess.Popen(
- tf2tfliteV2_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(tf2tfliteV2_cmd, logfile=f)
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
f.write((' '.join(tflite2circle_cmd) + '\n').encode())
# convert tflite to circle
- with subprocess.Popen(
- tflite2circle_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
def main():
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
import os
import subprocess
import sys
-import tempfile
import utils as _utils
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
def _get_parser():
parser = argparse.ArgumentParser(
f.write((' '.join(tflite2circle_cmd) + '\n').encode())
# convert tflite to circle
- with subprocess.Popen(
- tflite2circle_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
def main():
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
import os
import subprocess
import sys
-import tempfile
import utils as _utils
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
def _get_parser():
parser = argparse.ArgumentParser(
action='store_true',
help='generate profiling data')
+ utility_group.add_argument(
+ '--change_outputs',
+ type=str,
+ help='Experimental: Change first subgraph output nodes to CSV names')
+
## circle2circle arguments
circle2circle_group = parser.add_argument_group('arguments for optimization')
getattr(args, 'input_path'),
getattr(args, 'output_path'))
+ # verbose
+ if _utils._is_valid_attr(args, 'verbose'):
+ circle2circle_cmd.append('--verbose')
+ if _utils._is_valid_attr(args, 'change_outputs'):
+ circle2circle_cmd.append('--change_outputs')
+ circle2circle_cmd.append(getattr(args, 'change_outputs'))
+
f.write((' '.join(circle2circle_cmd) + '\n').encode())
# optimize
- with subprocess.Popen(
- circle2circle_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(circle2circle_cmd, err_prefix="circle2circle", logfile=f)
def main():
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
import utils as _utils
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
def _get_parser():
parser = argparse.ArgumentParser(
f.write((' '.join(model2nnpkg_cmd) + '\n').encode())
 # pack circle model and metadata into nnpackage
- with subprocess.Popen(
- model2nnpkg_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(model2nnpkg_cmd, err_prefix="model2nnpkg.sh", logfile=f)
def main():
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
DRIVER_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VENV_ACTIVATE=${DRIVER_PATH}/venv/bin/activate
-
-function error_no_ensurepip ()
-{
- echo "ERROR: python3 'ensurepip' module is not found."
- echo " On ubuntu, try following command:"
- echo
- echo " apt install python$(python3 --version | awk '{print $2}' | awk -F. '{print $1"."$2}')-venv"
- echo
- echo " You may need root privilege for this."
- exit 1
-}
+# NOTE Use venv's python binary directly instead of plain python after `source activate`.
+# This script is called by a debian maintainer script, i.e. `postinst`.
+# Since the debian maintainer script runs with sudo, `source activate` is ignored.
+VENV_PYTHON=${DRIVER_PATH}/venv/bin/python
if [ ! -f ${VENV_ACTIVATE} ]; then
- # Install prerequisites
- python3 -m ensurepip --version > /dev/null 2>&1 || error_no_ensurepip
- python3 -m pip install --user -U virtualenv
 # Create python virtual environment
python3 -m venv "${DRIVER_PATH}/venv"
fi
+# NOTE version
+# - https://github.com/onnx/onnx/blob/master/docs/Versioning.md
+# - https://github.com/onnx/onnx-tensorflow/blob/master/Versioning.md
+
+VER_TENSORFLOW=2.3.0
+VER_ONNX=1.8.0
+VER_ONNX_TF=1.8.0
+
# Install tensorflow
-source "${VENV_ACTIVATE}"
+
+PIP_TRUSTED_HOST="--trusted-host pypi.org "
+PIP_TRUSTED_HOST+="--trusted-host files.pythonhost.org "
+PIP_TRUSTED_HOST+="--trusted-host download.pytorch.org "
+
+PIP_TIMEOUT="--default-timeout=1000 "
+
+PIP_OPTIONS="${PIP_TIMEOUT} ${PIP_TRUSTED_HOST}"
+
+# NOTE $ONE_PREPVENV_PIP_OPTION provides additional PIP options,
+# such as a certificate file behind a firewall
+# ex) ONE_PREPVENV_PIP_OPTION="--cert SomePrivateCertificate.crt" ./one-prepare-venv
+if [[ ! -z "$ONE_PREPVENV_PIP_OPTION" ]]; then
+ PIP_OPTIONS+=" ${ONE_PREPVENV_PIP_OPTION} "
+fi
# TODO remove version number of 'pip==20.2.1 setuptools==49.3.0'
# NOTE adding version is for temporary hotfix of setuptools 50.x.y version
-python -m pip --default-timeout=1000 --trusted-host pypi.org --trusted-host files.pythonhost.org \
- install -U pip==20.2.1 setuptools==49.3.0
-python -m pip --default-timeout=1000 --trusted-host pypi.org --trusted-host files.pythonhost.org \
- install tensorflow-cpu==2.3.0
-python -m pip --default-timeout=1000 --trusted-host pypi.org --trusted-host files.pythonhost.org \
- install Pillow==6.2.2
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install -U pip==20.2.1 setuptools==49.3.0
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install tensorflow-cpu==${VER_TENSORFLOW}
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install Pillow==6.2.2
# Install PyTorch and ONNX related
-python -m pip --default-timeout=1000 --trusted-host pypi.org --trusted-host files.pythonhost.org \
- --trusted-host download.pytorch.org \
- install torch==1.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+${VENV_PYTHON} -m pip ${PIP_OPTIONS} install torch==1.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
-# NOTE Latest onnx 1.8.1 has compatibility issue with onnx-tf 1.7.0
-# MUST install with onnx==1.8.0
# Provide install of custom onnx-tf
if [ -n "${EXT_ONNX_TF_WHL}" ]; then
- python -m pip --default-timeout=1000 install onnx==1.8.0 ${EXT_ONNX_TF_WHL}
+ ${VENV_PYTHON} -m pip ${PIP_OPTIONS} install onnx==${VER_ONNX} ${EXT_ONNX_TF_WHL}
else
- python -m pip --default-timeout=1000 --trusted-host pypi.org --trusted-host files.pythonhost.org \
- install onnx==1.8.0 onnx-tf==1.7.0
+ ${VENV_PYTHON} -m pip ${PIP_OPTIONS} install onnx==${VER_ONNX} onnx-tf==${VER_ONNX_TF}
fi
-# Create python symoblic link
-rm -f ${DRIVER_PATH}/python
-ln -s venv/bin/python ${DRIVER_PATH}/python
+# TODO remove this patch after onnx-tf next release
+# apply patch for DWConv conversion bug: https://github.com/onnx/onnx-tensorflow/pull/905
+if [[ -z "${EXT_ONNX_TF_WHL}" ]]; then
+ PY_SITE_PACKAGES=$(${VENV_PYTHON} -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])')
+ if [[ -d ${PY_SITE_PACKAGES} ]]; then
+ pushd ${PY_SITE_PACKAGES} > /dev/null
+ PATCH_TARGET_FILE=onnx_tf/handlers/backend/conv_mixin.py
+ if [[ -f "${PATCH_TARGET_FILE}" ]]; then
+ # if patch is already applied, error code is 1
+ # catch error code and check if this is the case
+ set +e
+ patch -t -N -p1 < ${DRIVER_PATH}/conv_mixin_1.8.0.patch
+ ret_code=$?
+ [[ $ret_code -gt 1 ]] && exit $ret_code
+ set -e
+ fi
+ popd > /dev/null
+ fi
+fi
--- /dev/null
+#!/usr/bin/env bash
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
+''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
+''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
+''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
+''''exit 255 # '''
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import glob
+import itertools
+import ntpath
+import os
+import subprocess
+import sys
+import tempfile
+
+import utils as _utils
+
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
+
+def _get_backends_list():
+ """
+ [one hierarchy]
+ one
+ ├── backends
+ ├── bin
+ ├── doc
+ ├── include
+ ├── lib
+ └── test
+
+ The list where `one-profile` finds its backends
+ - `bin` folder where `one-profile` exists
+ - `backends` folder
+
+ NOTE If there are backends of the same name in different places,
+ the closer to the top in the list, the higher the priority.
+ """
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ backend_set = set()
+
+ # bin folder
+ files = [f for f in glob.glob(dir_path + '/*-profile')]
+ # backends folder
+ files += [
+ f for f in glob.glob(dir_path + '/../backends/**/*-profile', recursive=True)
+ ]
+ # TODO find backends in `$PATH`
+
+ backends_list = []
+ for cand in files:
+ base = ntpath.basename(cand)
+ if not base in backend_set and os.path.isfile(cand) and os.access(cand, os.X_OK):
+ backend_set.add(base)
+ backends_list.append(cand)
+
+ return backends_list
+
+
+def _get_parser(backends_list):
+ profile_usage = 'one-profile [-h] [-v] [-C CONFIG] [-b BACKEND] [--] [COMMANDS FOR BACKEND]'
+ parser = argparse.ArgumentParser(
+ description='command line tool for profiling backend model', usage=profile_usage)
+
+ _utils._add_default_arg(parser)
+
+ # get backend list in the directory
+ backends_name = [ntpath.basename(f) for f in backends_list]
+ if not backends_name:
+ backends_name_message = '(There is no available backend drivers)'
+ else:
+ backends_name_message = '(available backend drivers: ' + ', '.join(
+ backends_name) + ')'
+ backend_help_message = 'backend name to use ' + backends_name_message
+ parser.add_argument('-b', '--backend', type=str, help=backend_help_message)
+
+ return parser
+
+
+def _verify_arg(parser, args):
+ """verify given arguments"""
+ # check if required arguments are given
+ missing = []
+ if not _utils._is_valid_attr(args, 'backend'):
+ missing.append('-b/--backend')
+ if len(missing):
+ parser.error('the following arguments are required: ' + ' '.join(missing))
+
+
+def _parse_arg(parser):
+ profile_args = []
+ backend_args = []
+ unknown_args = []
+ argv = copy.deepcopy(sys.argv)
+ # delete file name
+ del argv[0]
+ # split by '--'
+ args = [list(y) for x, y in itertools.groupby(argv, lambda z: z == '--') if not x]
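+ # e.g. a hypothetical argv of ['-b', 'dummy', '--', '--opt', 'model.bin'] splits
+ # into [['-b', 'dummy'], ['--opt', 'model.bin']]; without '--' it stays one group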
+ # one-profile has two interfaces
+ # 1. one-profile [-h] [-v] [-C CONFIG] [-b BACKEND] [COMMANDS FOR BACKEND]
+ if len(args) == 1:
+ profile_args = args[0]
+ profile_args, unknown_args = parser.parse_known_args(profile_args)
+ # 2. one-profile [-h] [-v] [-C CONFIG] [-b BACKEND] -- [COMMANDS FOR BACKEND]
+ if len(args) == 2:
+ profile_args = args[0]
+ backend_args = args[1]
+ profile_args = parser.parse_args(profile_args)
+ # print version
+ if len(args) and profile_args.version:
+ _utils._print_version_and_exit(__file__)
+
+ return profile_args, backend_args, unknown_args
+
+
+def main():
+ # get backend list
+ backends_list = _get_backends_list()
+
+ # parse arguments
+ parser = _get_parser(backends_list)
+ args, backend_args, unknown_args = _parse_arg(parser)
+
+ # parse configuration file
+ _utils._parse_cfg(args, 'one-profile')
+
+ # verify arguments
+ _verify_arg(parser, args)
+
+ # make a command to run given backend driver
+ profile_path = None
+ backend_base = getattr(args, 'backend') + '-profile'
+ for cand in backends_list:
+ if ntpath.basename(cand) == backend_base:
+ profile_path = cand
+ if not profile_path:
+ raise FileNotFoundError(backend_base + ' not found')
+ profile_cmd = [profile_path] + backend_args + unknown_args
+ if _utils._is_valid_attr(args, 'command'):
+ profile_cmd += getattr(args, 'command').split()
+
+ # run backend driver
+ with subprocess.Popen(
+ profile_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+ bufsize=1) as p:
+ for line in p.stdout:
+ sys.stdout.buffer.write(line)
+ sys.stdout.buffer.flush()
+ if p.returncode != 0:
+ sys.exit(p.returncode)
+
+
+if __name__ == '__main__':
+ main()
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
import utils as _utils
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
def _get_parser():
parser = argparse.ArgumentParser(
'full filepath of the input data file. if not specified, run with random input data.'
)
parser.add_argument(
+ '-f',
+ '--input_data_format',
+ type=str,
+ help=
+ 'file format of input data. h5/hdf5 (default), list/filelist (a text file where a file path of input data is written in each line), or dir/directory (a directory where input data are saved)'
+ )
+ parser.add_argument(
'-o', '--output_path', type=str, help='full filepath of the output file')
# argument for profiling
## make a command to quantize and dequantize the weights of the model
circle_quantizer_cmd = [circle_quantizer_path]
+ # verbose
+ if _utils._is_valid_attr(args, 'verbose'):
+ circle_quantizer_cmd.append('--verbose')
# quantize_dequantize_weights
circle_quantizer_cmd.append('--quantize_dequantize_weights')
if _utils._is_valid_attr(args, 'input_dtype'):
f.write((' '.join(circle_quantizer_cmd) + '\n').encode())
# run circle-quantizer
- with subprocess.Popen(
- circle_quantizer_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
## make a command to record min-max value of each tensor while running the representative dataset
circle_record_minmax_cmd = [record_minmax_path]
+ # verbose
+ if _utils._is_valid_attr(args, 'verbose'):
+ circle_record_minmax_cmd.append('--verbose')
# input and output path
circle_record_minmax_cmd.append('--input_model')
circle_record_minmax_cmd.append(tmp_output_path_1)
if _utils._is_valid_attr(args, 'input_data'):
circle_record_minmax_cmd.append('--input_data')
circle_record_minmax_cmd.append(getattr(args, 'input_data'))
+ if _utils._is_valid_attr(args, 'input_data_format'):
+ circle_record_minmax_cmd.append('--input_data_format')
+ circle_record_minmax_cmd.append(getattr(args, 'input_data_format'))
# min and max percentile
if _utils._is_valid_attr(args, 'min_percentile'):
circle_record_minmax_cmd.append('--min_percentile')
f.write((' '.join(circle_record_minmax_cmd) + '\n').encode())
# run record-minmax
- with subprocess.Popen(
- circle_record_minmax_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(circle_record_minmax_cmd, err_prefix="record_minmax", logfile=f)
## make a second command to quantize the model using the embedded information
circle_quantizer_cmd = [circle_quantizer_path]
+ # verbose
+ if _utils._is_valid_attr(args, 'verbose'):
+ circle_quantizer_cmd.append('--verbose')
# quantize_dequantize_weights
circle_quantizer_cmd.append('--quantize_with_minmax')
if _utils._is_valid_attr(args, 'input_dtype'):
f.write((' '.join(circle_quantizer_cmd) + '\n').encode())
# run circle-quantizer
- with subprocess.Popen(
- circle_quantizer_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
- if p.returncode != 0:
- sys.exit(p.returncode)
+ _utils._run(circle_quantizer_cmd, err_prefix="circle_quantizer", logfile=f)
def main():
if __name__ == '__main__':
- main()
+ _utils._safemain(main, __file__)
--- /dev/null
+#!/usr/bin/env bash
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
+''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
+''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
+''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
+''''exit 255 # '''
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import configparser
+import os
+import subprocess
+import sys
+
+import utils as _utils
+
+# TODO Find better way to suppress traceback on error
+sys.tracebacklimit = 0
+
+subtool_list = {
+ 'compile': {
+ 'import': 'Convert given model to circle',
+ 'optimize': 'Optimize circle model',
+ 'quantize': 'Quantize circle model',
+ },
+ 'package': {
+ 'pack': 'Package circle and metadata into nnpackage',
+ },
+ 'backend': {
+ 'codegen': 'Code generation tool',
+ 'profile': 'Profile backend model file',
+ },
+}
+
+
+def _call_driver(driver_name, options):
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ driver_path = os.path.join(dir_path, driver_name)
+ cmd = [driver_path] + options
+ _utils._run(cmd)
+
+
+def _check_subtool_exists():
+ """verify given arguments"""
+ subtool_keys = [n for k, v in subtool_list.items() for n in v.keys()]
+ if len(sys.argv) > 1 and sys.argv[1] in subtool_keys:
+ driver_name = 'one-' + sys.argv[1]
+ options = sys.argv[2:]
+ _call_driver(driver_name, options)
+ sys.exit(0)
+
+
+def _get_parser():
+ onecc_usage = 'onecc [-h] [-v] [-C CONFIG] [COMMAND <args>]'
+ onecc_desc = 'Run ONE driver via several commands or configuration file'
+ parser = argparse.ArgumentParser(description=onecc_desc, usage=onecc_usage)
+
+ _utils._add_default_arg(parser)
+
+ # just for help message
+ compile_group = parser.add_argument_group('compile to circle model')
+ for tool, desc in subtool_list['compile'].items():
+ compile_group.add_argument(tool, action='store_true', help=desc)
+
+ package_group = parser.add_argument_group('package circle model')
+ for tool, desc in subtool_list['package'].items():
+ package_group.add_argument(tool, action='store_true', help=desc)
+
+ backend_group = parser.add_argument_group('run backend tools')
+ for tool, desc in subtool_list['backend'].items():
+ backend_group.add_argument(tool, action='store_true', help=desc)
+
+ return parser
+
+
+def _parse_arg(parser):
+ args = parser.parse_args()
+ # print version
+ if args.version:
+ _utils._print_version_and_exit(__file__)
+
+ return args
+
+
+def _verify_arg(parser, args):
+ """verify given arguments"""
+ # check if required arguments are given
+ if not _utils._is_valid_attr(args, 'config'):
+ parser.error('-C/--config argument is required')
+
+
+def _get_driver_name(driver_name):
+ return {
+ 'one-import-bcq': 'one-import-bcq',
+ 'one-import-tf': 'one-import-tf',
+ 'one-import-tflite': 'one-import-tflite',
+ 'one-import-onnx': 'one-import-onnx',
+ 'one-optimize': 'one-optimize',
+ 'one-quantize': 'one-quantize',
+ 'one-pack': 'one-pack',
+ 'one-codegen': 'one-codegen',
+ 'one-profile': 'one-profile'
+ }[driver_name]
+
+
+def _parse_cfg(args):
+ config = configparser.ConfigParser()
+ config.optionxform = str
+ parsed = config.read(os.path.expanduser(getattr(args, 'config')))
+ if not parsed:
+ raise FileNotFoundError('Not found given configuration file')
+ return config
+
+
+def _is_available_driver(config, driver_name):
+ return config.has_option('onecc', driver_name) and config.getboolean(
+ 'onecc', driver_name)
+
+
+def _verify_cfg(driver_list, config):
+ if not config.has_section('onecc'):
+ raise ImportError('[onecc] section is required in configuration file')
+
+ import_driver_cnt = 0
+ if _is_available_driver(config, 'one-import-tf'):
+ import_driver_cnt += 1
+ if _is_available_driver(config, 'one-import-tflite'):
+ import_driver_cnt += 1
+ if _is_available_driver(config, 'one-import-bcq'):
+ import_driver_cnt += 1
+ if _is_available_driver(config, 'one-import-onnx'):
+ import_driver_cnt += 1
+ if import_driver_cnt > 1:
+ raise AssertionError('Only one import-* driver can be executed')
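+ # e.g. a cfg enabling both one-import-tf and one-import-onnx in [onecc] is
+ # rejected above; at most one of the one-import-* drivers may be set to True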
+
+
+def main():
+ # check if there is subtool argument
+ # if true, it executes subtool with argv
+ # NOTE:
+ # Why call the subtool directly without using Argparse?
+ # Because if Argparse were used, options that onecc itself recognizes,
+ # including '--help' and '-C', would be consumed directly by onecc,
+ # so they could not be delivered to the subtool.
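+ # e.g. `onecc quantize --help` is handed straight to `one-quantize --help`,
+ # while `onecc -C onecc.template.cfg` falls through to the config-driven flow below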
+ _check_subtool_exists()
+
+ # parse arguments
+ # since the configuration file path is required first,
+ # parsing of the configuration file proceeds after this.
+ parser = _get_parser()
+ args = _parse_arg(parser)
+
+ # verify arguments
+ _verify_arg(parser, args)
+
+ # parse configuration file
+ config = _parse_cfg(args)
+
+ # verify configuration file
+ drivers = [
+ 'one-import-tf', 'one-import-tflite', 'one-import-bcq', 'one-import-onnx',
+ 'one-optimize', 'one-quantize', 'one-pack', 'one-codegen', 'one-profile'
+ ]
+ _verify_cfg(drivers, config)
+
+ # get sections to run
+ section_to_run = []
+ for d in drivers:
+ if _is_available_driver(config, d):
+ section_to_run.append(d)
+
+ # run
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ for section in section_to_run:
+ driver_name = _get_driver_name(section)
+ options = ['--config', getattr(args, 'config'), '--section', section]
+ if _utils._is_valid_attr(args, 'verbose'):
+ options.append('--verbose')
+ _call_driver(driver_name, options)
+
+
+if __name__ == '__main__':
+ _utils._safemain(main, __file__)
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-import-onnx=False
+one-optimize=True
+one-quantize=False
+one-pack=True
+one-codegen=False
+one-profile=False
+
+[one-import-tf]
+input_path=/path/to/inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+model_format=graph_def
+
+[one-optimize]
+input_path=inception_v3.circle
+output_path=inception_v3.opt.circle
+generate_profile_data=False
+
+[one-pack]
+input_path=inception_v3.opt.circle
+output_path=inception_v3_pack
outputfile="inception_v3.opt.circle"
# run test
-one-build -C ${configfile} > /dev/null
+one-build -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
outputfile="inception_v3_pkg"
# run test
-one-build -C ${configfile} > /dev/null
+one-build -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
rm -rf ${outputfile}
# run test
-one-build -C ${configfile} > /dev/null
+one-build -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
cp dummy-compile ../bin/dummy-compile
# run test
-one-build -C ${configfile} > /dev/null
+one-build -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
cp dummy-compile ../bin/dummy-compile
# run test
-one-build -C ${configfile} > /dev/null
+one-build -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
cp dummy-compile ../bin/dummy-compile
# run test
-one-build -C ${configfile} > /dev/null
+one-build -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
rm -rf ${outputfile}
# run test
-one-build -C ${configfile} > /dev/null
+one-build -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
cp dummy-compile ../bin/dummy-compile
# run test
-one-build -C ${configfile} > /dev/null
+one-build -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
cp dummy-compile ../bin/dummy-compile
# run test
-one-build -C ${configfile} > /dev/null
+one-build -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
--- /dev/null
+[one-build]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.alt.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+save_intermediate=True
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf: intermediate file should exist
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="one-build_010.cfg"
+outputfile="inception_v3.alt.circle"
+intermfile="inception_v3.alt.tflite"
+
+rm -rf ${outputfile}
+rm -rf ${intermfile}
+
+# run test
+one-build -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+if [[ ! -s "${intermfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[one-build]
+one-import-tf=False
+one-import-tflite=False
+one-import-bcq=False
+one-import-onnx=True
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-onnx]
+input_path=test_onnx_model.onnx
+output_path=test_onnx_model.circle
+save_intermediate=True
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-onnx
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="one-build_011.cfg"
+outputfile="test_onnx_model.circle"
+intermfile="test_onnx_model.tflite"
+
+rm -rf ${outputfile}
+rm -rf ${intermfile}
+
+# run test
+one-build -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+if [[ ! -s "${intermfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[one-build]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=True
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+
+[one-quantize]
+input_path=inception_v3.circle
+output_path=inception_v3.list.quantized.circle
+input_data=datalist.txt
+input_data_format=list
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf -> one-quantize
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="one-build_012.cfg"
+outputfile="inception_v3.list.quantized.circle"
+
+rm -rf ${outputfile}
+
+# run test
+one-build -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[one-build]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=True
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+
+[one-quantize]
+input_path=inception_v3.circle
+output_path=inception_v3.dir.quantized.circle
+input_data=raw_files
+input_data_format=directory
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf -> one-quantize
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="one-build_013.cfg"
+outputfile="inception_v3.dir.quantized.circle"
+
+rm -rf ${outputfile}
+
+# run test
+one-build -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[one-build]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.alt.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v2
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative one-import-tf intermediate file should not exist
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="one-build_neg_005.cfg"
+outputfile="inception_v3.alt.circle"
+intermfile="inception_v3.alt.tflite"
+
+rm -rf ${outputfile}
+rm -rf ${intermfile}
+
+# run test
+one-build -C ${configfile} > ${filename}.log 2>&1
+
+# output should exist
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+# intermediate file should not exist
+if [[ -f "${intermfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[one-build]
+one-import-tf=False
+one-import-tflite=False
+one-import-bcq=False
+one-import-onnx=True
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-onnx]
+input_path=test_onnx_model.onnx
+output_path=test_onnx_model.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-onnx: intermediate file should not exist
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="one-build_neg_006.cfg"
+outputfile="test_onnx_model.circle"
+intermfile="test_onnx_model.tflite"
+
+rm -rf ${outputfile}
+rm -rf ${intermfile}
+
+# run test
+one-build -C ${configfile} > ${filename}.log 2>&1
+
+# output should exist
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+# intermediate file should not exist
+if [[ -f "${intermfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
cp help-compile ../bin/help-compile
# run test
-one-codegen -b help -- -h > ${filename}.log
+one-codegen -b help -- -h > ${filename}.log 2>&1
rm -rf ../bin/help-compile
cp dummy-compile ../bin/dummy-compile
# run test
-one-codegen -b dummy -o ${outputfile} "dummy.circle"
+one-codegen -b dummy -o ${outputfile} "dummy.circle" > ${filename}.log 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
cp dummy-compile ../bin/dummy-compile
# run test
-one-codegen -b dummy -- -o ${outputfile} "dummy.circle"
+one-codegen -b dummy -- -o ${outputfile} "dummy.circle" > ${filename}.log 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
trap trap_err_onexit ERR
# run test
-one-codegen -h > ${filename}.log
+one-codegen -h > ${filename}.log 2>&1
if grep -q "command line tool for code generation" "${filename}.log"; then
echo "${filename_ext} SUCCESS"
outputfile="./bcq.circle"
rm -rf $outputfile
-rm -rf $outputfile.log
# run test
one-import-bcq \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder \
---output_arrays MatMul >> /dev/null
+--output_arrays MatMul > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder_null \
---output_arrays MatMul > ${filename}.log
+--output_arrays MatMul > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder \
---output_arrays MatMul_null > ${filename}.log
+--output_arrays MatMul_null > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder \
---output_arrays MatMul > ${filename}.log
+--output_arrays MatMul > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder \
---output_arrays MatMul > ${filename}.log
+--output_arrays MatMul > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder \
---output_arrays MatMul > ${filename}.log
+--output_arrays MatMul > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder --input_shapes "1,32,32" \
---output_arrays MatMul > ${filename}.log
+--output_arrays MatMul > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder --input_shapes "30,30" \
---output_arrays MatMul > ${filename}.log
+--output_arrays MatMul > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder --input_shapes "32,O" \
---output_arrays MatMul > ${filename}.log
+--output_arrays MatMul > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Placeholder --input_shapes "32,32:1" \
---output_arrays MatMul > ${filename}.log
+--output_arrays MatMul > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input --input_shapes "1,299,299,3" \
---output_arrays InceptionV3/Predictions/Reshape_1 >> /dev/null
+--output_arrays InceptionV3/Predictions/Reshape_1 > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
# run test
one-import tf -C ${configfile} \
---output_path=inception_v3_cmd.circle > /dev/null
+--output_path=inception_v3_cmd.circle > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
rm -f ${outputfile}
# run test
-one-import tf -C ${configfile} > /dev/null
+one-import tf -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
rm -f ${outputfile}
# run test
-one-import tf -C ${configfile} > /dev/null
+one-import tf -C ${configfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
# See the License for the specific language governing permissions and
# limitations under the License.
+# import onnx model with cfg file
+
filename_ext="$(basename -- $0)"
filename="${filename_ext%.*}"
rm -f ${outputfile}
# run test
-one-build -C ${configfile} > ${filename}.log 2>&1
+one-import onnx -C ${configfile} > ${filename}.log 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# import onnx model
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="test_onnx_model.onnx"
+outputfile="test_onnx_model.circle"
+
+rm -f ${outputfile}
+
+# run test
+one-import onnx -i ${inputfile} -o ${outputfile} > ${filename}.log 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
outputfile="./inception_v3.circle"
rm -rf ${outputfile}
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input --input_shapes "1,299,299,3" \
---output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log
+--output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
outputfile="./while_3.circle"
rm -rf ${outputfile}
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays Hole,Hole_2 --input_shapes "1,1:1,1" \
---output_arrays Output > ${filename}.log
+--output_arrays Output > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
trap_err_onexit()
{
+ # TODO Error message depends on TF version. Find better way.
+ # TF 2.3.0
if grep -q "ValueError: Invalid tensors" "${filename}.log"; then
echo "${filename_ext} SUCCESS"
exit 0
fi
+ # TF 2.5.0
+ if grep -q "ConverterError: <unknown>:0: error:" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
echo "${filename_ext} FAILED"
exit 255
}
outputfile="./inception_v3.circle"
rm -rf ${outputfile}
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input --input_shapes "1,299,299,3" \
---output_arrays InceptionV3/Predictions/Reshape_2 > ${filename}.log
+--output_arrays InceptionV3/Predictions/Reshape_2 > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
outputfile="./inception_v3.circle"
rm -rf ${outputfile}
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input --input_shapes "1,299,299,1" \
---output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log
+--output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
outputfile="./inception_v3.circle"
rm -rf ${outputfile}
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input --input_shapes "1,299,299" \
---output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log
+--output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
outputfile="./inception_v3.circle"
rm -rf ${outputfile}
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input --input_shapes "0,299,299,3" \
---output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log
+--output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
outputfile="./inception_v3.circle"
rm -rf ${outputfile}
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input --input_shapes "None,299,299,3" \
---output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log
+--output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
outputfile="./inception_v3.circle"
rm -rf ${outputfile}
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input,InceptionV3/Predictions/Shape --input_shapes "1,299,299,3" \
---output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log
+--output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
inputfile="./inception_v3.pb"
outputfile="."
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input --input_shapes "1,299,299,3" \
---output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log
+--output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
outputfile="./inception_v3.circle"
rm -rf ${outputfile}
-rm -rf ${outputfile}.log
+rm -rf ${filename}.log
# run test
one-import tf \
--input_path ${inputfile} \
--output_path ${outputfile} \
--input_arrays input2 --input_shapes "1,299,299,3" \
---output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log
+--output_arrays InceptionV3/Predictions/Reshape_1 > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
# to create inception_v3.circle
if [[ ! -s ${inputfile} ]]; then
- /bin/bash one-import_001.test >> /dev/null
+ /bin/bash one-import_001.test > /dev/null 2>&1
return_code=$?
if [[ ${return_code} != 0 ]]; then
trap_err_onexit
# run test
one-optimize --O1 \
--input_path ${inputfile} \
---output_path ${outputfile} >> /dev/null
+--output_path ${outputfile} > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
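+# one-optimize with change_outputs option
+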
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="./inception_v3.circle"
+outputfile="./inception_v3-opt.circle"
+
+rm -rf ${outputfile}
+
+# to create inception_v3.circle
+if [[ ! -s ${inputfile} ]]; then
+ /bin/bash one-import_001.test > /dev/null 2>&1
+ return_code=$?
+ if [[ ${return_code} != 0 ]]; then
+ trap_err_onexit
+ fi
+fi
+
+# run test
+one-optimize --O1 \
+--change_outputs InceptionV3/Logits/SpatialSqueeze1 \
+--input_path ${inputfile} \
+--output_path ${outputfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
# run test
one-optimize --O1 \
--input_path ${inputfile} \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
# run test
one-optimize --O1 \
--input_path ${inputfile} \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage with non-existing node name for change_outputs
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "Change outputs failed" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="./inception_v3.circle"
+outputfile="./inception_v3-opt.circle"
+
+rm -rf ${outputfile}
+rm -rf ${filename}.log
+
+# run test
+one-optimize --O1 \
+--change_outputs non_existing_node_name \
+--input_path ${inputfile} \
+--output_path ${outputfile} > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
# to create inception_v3.circle
if [[ ! -s ${inputfile} ]]; then
- /bin/bash one-import_001.test >> /dev/null
+ /bin/bash one-import_001.test > /dev/null 2>&1
return_code=$?
if [[ ${return_code} != 0 ]]; then
trap_err_onexit
# run test
one-pack \
-i ${inputfile} \
--o ${outputfolder} >> /dev/null
+-o ${outputfolder} > /dev/null 2>&1
if [[ ! -d "${outputfolder}" ]]; then
trap_err_onexit
# run test
one-pack \
-i ./inception_v2.circle \
--o nnpack > ${filename}.log
+-o nnpack > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
# run test
one-pack \
-i ./sample \
--o nnpack > ${filename}.log
+-o nnpack > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
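+# run one-profile with help-profile driver
+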
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ rm -rf ../bin/help-profile
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+# copy help-profile to bin folder
+cp help-profile ../bin/help-profile
+
+# run test
+one-profile -b help -- -h > ${filename}.log
+
+rm -rf ../bin/help-profile
+
+if grep -q "HELP MESSAGE!!" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+fi
+
+trap_err_onexit
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# run one-profile with dummy-profile driver
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ rm -rf ../bin/dummy-profile
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="sample.tvn"
+
+if [[ ! -s "${inputfile}" ]]; then
+ touch ${inputfile}
+fi
+
+# copy dummy-profile to bin folder
+cp dummy-profile ../bin/dummy-profile
+
+# run test
+one-profile -b dummy ${inputfile} > ${filename}.log
+
+rm -rf ../bin/dummy-profile
+
+if grep -q "dummy-profile dummy output!!!" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+fi
+
+trap_err_onexit
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# print one-profile's help message
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+# run test
+one-profile -h > ${filename}.log
+
+if grep -q "command line tool for profiling backend model" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+fi
+
+trap_err_onexit
--- /dev/null
+[one-profile]
+backend=dummy
+command=sample.tvn
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-profile with configuration input
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ rm -rf ../bin/dummy-profile
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="one-profile_004.cfg"
+inputfile="sample.tvn"
+
+if [[ ! -s "${inputfile}" ]]; then
+ touch ${inputfile}
+fi
+
+# copy dummy-profile to bin folder
+cp dummy-profile ../bin/dummy-profile
+
+# run test
+one-profile -C ${configfile} > ${filename}.log
+
+rm -rf ../bin/dummy-profile
+
+if grep -q "dummy-profile dummy output!!!" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+fi
+
+trap_err_onexit
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage without backend option
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "error: the following arguments are required: -b/--backend" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+# run test
+one-profile > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
# to create inception_v3.circle
if [[ ! -s ${inputfile} ]]; then
- /bin/bash one-import_001.test >> /dev/null
+ /bin/bash one-import_001.test > /dev/null 2>&1
return_code=$?
if [[ ${return_code} != 0 ]]; then
trap_err_onexit
--quantized_dtype uint8 \
--input_path ./inception_v3.circle \
--input_data ./inception_v3_test_data.h5 \
---output_path ./inception_v3.quantized.circle >> /dev/null
+--output_path ./inception_v3.quantized.circle > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
# to create inception_v3.circle
if [[ ! -s ${inputfile} ]]; then
- /bin/bash one-import_001.test >> /dev/null
+ /bin/bash one-import_001.test > /dev/null 2>&1
return_code=$?
if [[ ${return_code} != 0 ]]; then
trap_err_onexit
--input_dtype float32 \
--quantized_dtype uint8 \
--input_path ./inception_v3.circle \
---output_path ./inception_v3.random.quantized.circle >> /dev/null
+--output_path ./inception_v3.random.quantized.circle > /dev/null 2>&1
if [[ ! -s "${outputfile}" ]]; then
trap_err_onexit
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
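+# one-quantize with list-format input data
+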
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="./inception_v3.circle"
+outputfile="./inception_v3.list.quantized.circle"
+
+rm -rf ${outputfile}
+
+# run test with list-format input data (datalist.txt)
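+# (datalist.txt is assumed to list paths of input data files, one per line)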
+one-quantize \
+--input_dtype float32 \
+--quantized_dtype uint8 \
+--input_path ./inception_v3.circle \
+--input_data ./datalist.txt \
+--input_data_format list \
+--output_path ${outputfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
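+# one-quantize with directory-format input data
+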
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="./inception_v3.circle"
+outputfile="./inception_v3.directory.quantized.circle"
+
+rm -rf ${outputfile}
+
+# run test with directory-format input data (raw_files)
+one-quantize \
+--input_dtype float32 \
+--quantized_dtype uint8 \
+--input_path ${inputfile} \
+--input_data ./raw_files \
+--input_data_format directory \
+--output_path ${outputfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--quantized_dtype uint8 \
--input_path ${inputfile} \
--input_data ${inputdata} \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--quantized_dtype uint16 \
--input_path ${inputfile} \
--input_data ${inputdata} \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--quantized_dtype uint8 \
--input_path ${inputfile} \
--input_data ./mobilenet_test_data.h5 \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--quantized_dtype uint8 \
--input_path ${inputfile} \
--input_data ${inputdata} \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--quantized_dtype uint8 \
--input_path ${inputfile} \
--input_data ${inputdata} \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--quantized_dtype uint8 \
--input_path ${inputfile} \
--input_data ${inputdata} \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--quantized_dtype uint8 \
--input_path ${inputfile} \
--input_data ${inputdata} \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--input_data ${inputdata} \
--mode average \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--input_data ${inputdata} \
--max_percentile 101 \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--input_data ${inputdata} \
--max_percentile -1 \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--input_data ${inputdata} \
--min_percentile 101 \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--input_data ${inputdata} \
--min_percentile -1 \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--input_path ${inputfile} \
--input_data ${inputdata} \
--granularity layered \
---output_path ${outputfile} > ${filename}.log
+--output_path ${outputfile} > ${filename}.log 2>&1
echo "${filename_ext} FAILED"
exit 255
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage with list data file given as h5 input_data_format
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "Given data file is not HDF5" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="./inception_v3.circle"
+inputdata="./datalist.txt"
+outputfile="./inception_v3.quantized.circle"
+
+rm -rf ${filename}.log
+
+# run test
+one-quantize \
+--input_dtype float32 \
+--quantized_dtype uint8 \
+--input_path ${inputfile} \
+--input_data ${inputdata} \
+--input_data_format h5 \
+--granularity channel \
+--output_path ${outputfile} > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage with h5 data file given as list input_data_format
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "Cannot open file" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="./inception_v3.circle"
+inputdata="./inception_v3_test_data.h5"
+outputfile="./inception_v3.quantized.circle"
+
+rm -rf ${filename}.log
+
+# run test
+one-quantize \
+--input_dtype float32 \
+--quantized_dtype uint8 \
+--input_path ${inputfile} \
+--input_data ${inputdata} \
+--input_data_format list \
+--granularity channel \
+--output_path ${outputfile} > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage with unsupported input_data_format
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "Unsupported input data format" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="./inception_v3.circle"
+inputdata="./datalist.txt"
+outputfile="./inception_v3.quantized.circle"
+
+rm -rf ${filename}.log
+
+# run test
+one-quantize \
+--input_dtype float32 \
+--quantized_dtype uint8 \
+--input_path ${inputfile} \
+--input_data ${inputdata} \
+--input_data_format h5list \
+--granularity channel \
+--output_path ${outputfile} > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage with h5 data file given as directory input_data_format
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "Cannot open directory" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="./inception_v3.circle"
+inputdata="./inception_v3_test_data.h5"
+outputfile="./inception_v3.quantized.circle"
+
+rm -rf ${filename}.log
+
+# run test
+one-quantize \
+--input_dtype float32 \
+--quantized_dtype uint8 \
+--input_path ${inputfile} \
+--input_data ${inputdata} \
+--input_data_format directory \
+--granularity channel \
+--output_path ${outputfile} > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=True
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v2
+
+[one-optimize]
+input_path=inception_v3.circle
+output_path=inception_v3.opt.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf -> one-optimize
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_001.cfg"
+outputfile="inception_v3.opt.circle"
+
+rm -rf ${outputfile}
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=True
+one-quantize=False
+one-pack=True
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v2
+
+[one-optimize]
+input_path=inception_v3.circle
+output_path=inception_v3.opt.circle
+
+[one-pack]
+input_path=inception_v3.opt.circle
+output_path=inception_v3_pkg
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf -> one-optimize -> one-pack
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_002.cfg"
+outputfile="inception_v3_pkg"
+
+rm -rf ${outputfile}
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=True
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+
+[one-quantize]
+input_path=inception_v3.circle
+output_path=inception_v3.quantized.circle
+input_data=inception_v3_test_data.h5
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf -> one-quantize
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_003.cfg"
+outputfile="inception_v3.quantized.circle"
+
+rm -rf ${outputfile}
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=True
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+
+[one-codegen]
+backend=dummy
+command=-o sample.tvn inception_v3.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf -> one-codegen
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ rm -rf ../bin/dummy-compile
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_004.cfg"
+outputfile="sample.tvn"
+
+rm -rf ${outputfile}
+
+# copy dummy-compile to bin folder
+cp dummy-compile ../bin/dummy-compile
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+rm -rf ../bin/dummy-compile
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=False
+one-import-tflite=True
+one-import-bcq=False
+one-optimize=True
+one-quantize=False
+one-pack=False
+one-codegen=True
+
+[one-import-tflite]
+input_path=inception_v3.tflite
+output_path=inception_v3.circle
+
+[one-optimize]
+input_path=inception_v3.circle
+output_path=inception_v3.opt.circle
+
+[one-codegen]
+backend=dummy
+command=-o sample.tvn inception_v3.opt.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tflite -> one-optimize -> one-codegen
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ rm -rf ../bin/dummy-compile
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_005.cfg"
+outputfile="sample.tvn"
+
+rm -rf ${outputfile}
+
+# copy dummy-compile to bin folder
+cp dummy-compile ../bin/dummy-compile
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+rm -rf ../bin/dummy-compile
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=True
+one-quantize=True
+one-pack=False
+one-codegen=True
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+
+[one-optimize]
+input_path=inception_v3.circle
+output_path=inception_v3.opt.circle
+
+[one-quantize]
+input_path=inception_v3.opt.circle
+output_path=inception_v3.quantized.circle
+input_data=inception_v3_test_data.h5
+
+[one-codegen]
+backend=dummy
+command=-o sample.tvn inception_v3.quantized.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf -> one-optimize -> one-quantize -> one-codegen
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ rm -rf ../bin/dummy-compile
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_006.cfg"
+outputfile="sample.tvn"
+
+rm -rf ${outputfile}
+
+# copy dummy-compile to bin folder
+cp dummy-compile ../bin/dummy-compile
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+rm -rf ../bin/dummy-compile
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=True
+one-quantize=True
+one-pack=True
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+
+[one-optimize]
+input_path=inception_v3.circle
+output_path=inception_v3.opt.circle
+
+[one-quantize]
+input_path=inception_v3.opt.circle
+output_path=inception_v3.quantized.circle
+input_data=inception_v3_test_data.h5
+
+[one-pack]
+input_path=inception_v3.quantized.circle
+output_path=inception_v3_pkg
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf -> one-optimize -> one-quantize -> one-pack
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_007.cfg"
+outputfile="inception_v3_pkg"
+
+rm -rf ${outputfile}
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=False
+one-import-tflite=False
+one-import-bcq=False
+one-import-onnx=True
+one-optimize=True
+one-quantize=False
+one-pack=False
+one-codegen=True
+
+[one-import-onnx]
+input_path=test_onnx_model.onnx
+output_path=test_onnx_model.circle
+
+[one-optimize]
+input_path=test_onnx_model.circle
+output_path=test_onnx_model.opt.circle
+all=True
+remove_redundant_transpose=True
+
+[one-codegen]
+backend=dummy
+command=-o test_onnx_model.bin test_onnx_model.opt.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-onnx -> one-optimize -> one-codegen
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ rm -rf ../bin/dummy-compile
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_008.cfg"
+outputfile="test_onnx_model.bin"
+
+rm -rf ${outputfile}
+
+# copy dummy-compile to bin folder
+cp dummy-compile ../bin/dummy-compile
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+rm -rf ../bin/dummy-compile
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=False
+one-import-tflite=False
+one-import-bcq=False
+one-import-onnx=True
+one-optimize=True
+one-quantize=False
+one-pack=False
+one-codegen=True
+
+[one-import-onnx]
+input_path=onnx_conv2d_conv2d.onnx
+output_path=onnx_conv2d_conv2d.circle
+
+[one-optimize]
+input_path=onnx_conv2d_conv2d.circle
+output_path=onnx_conv2d_conv2d.opt.circle
+all=True
+remove_redundant_transpose=True
+convert_nchw_to_nhwc=True
+
+[one-codegen]
+backend=dummy
+command=-o onnx_conv2d_conv2d.bin onnx_conv2d_conv2d.opt.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-onnx -> one-optimize -> one-codegen
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ rm -rf ../bin/dummy-compile
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_009.cfg"
+outputfile="onnx_conv2d_conv2d.bin"
+
+rm -rf ${outputfile}
+
+# copy dummy-compile to bin folder
+cp dummy-compile ../bin/dummy-compile
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+rm -rf ../bin/dummy-compile
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.alt.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+save_intermediate=True
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf: intermediate file should exist
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_010.cfg"
+outputfile="inception_v3.alt.circle"
+intermfile="inception_v3.alt.tflite"
+
+rm -rf ${outputfile}
+rm -rf ${intermfile}
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+if [[ ! -s "${intermfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=False
+one-import-tflite=False
+one-import-bcq=False
+one-import-onnx=True
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-onnx]
+input_path=test_onnx_model.onnx
+output_path=test_onnx_model.circle
+save_intermediate=True
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-onnx: intermediate file should exist
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_011.cfg"
+outputfile="test_onnx_model.circle"
+intermfile="test_onnx_model.tflite"
+
+rm -rf ${outputfile}
+rm -rf ${intermfile}
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+if [[ ! -s "${intermfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=True
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
+
+[one-quantize]
+input_path=inception_v3.circle
+output_path=inception_v3.list.quantized.circle
+input_data=datalist.txt
+input_data_format=list
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import-tf -> one-quantize
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_012.cfg"
+outputfile="inception_v3.list.quantized.circle"
+
+rm -rf ${outputfile}
+
+# run test
+onecc -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v1
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import {tf} with config file
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_013.cfg"
+outputfile="inception_v3.circle"
+
+rm -rf ${outputfile}
+
+# run test
+onecc import tf -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[one-import-tflite]
+input_path=inception_v3.tflite
+output_path=inception_v3.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import {tflite} with config file
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_014.cfg"
+outputfile="inception_v3.circle"
+
+rm -rf ${outputfile}
+
+# run test
+onecc import tflite -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[one-import-bcq]
+input_path=bcq.pb
+output_path=bcq.circle
+input_arrays=Placeholder
+output_arrays=MatMul
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import {bcq} with config file
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_015.cfg"
+outputfile="bcq.circle"
+
+rm -rf ${outputfile}
+
+# run test
+onecc import bcq -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[one-import-onnx]
+input_path=test_onnx_model.onnx
+output_path=test_onnx_model.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-import {onnx} with config file
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_016.cfg"
+outputfile="test_onnx_model.circle"
+
+rm -rf ${outputfile}
+
+# run test
+onecc import onnx -C ${configfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-optimize
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="inception_v3.circle"
+outputfile="inception_v3.opt.circle"
+
+if [[ ! -s "${inputfile}" ]]; then
+ echo "${filename_ext} ERROR: Missing inputfile"
+ trap_err_onexit
+fi
+
+rm -rf ${outputfile}
+
+# run test
+onecc optimize -i ${inputfile} -o ${outputfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-quantize
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="inception_v3.opt.circle"
+inputdata="inception_v3_test_data.h5"
+outputfile="inception_v3.quantized.circle"
+
+if [[ ! -s "${inputfile}" && ! -s "${inputdata}" ]]; then
+ echo "${filename_ext} ERROR: Missing inputfile"
+ trap_err_onexit
+fi
+
+rm -rf ${outputfile}
+
+# run test
+onecc quantize -i ${inputfile} -o ${outputfile} -d ${inputdata} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-pack
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="inception_v3.quantized.circle"
+outputfile="inception_v3_pkg"
+
+if [[ ! -s "${inputfile}" ]]; then
+ echo "${filename_ext} ERROR: Missing inputfile"
+ trap_err_onexit
+fi
+
+rm -rf ${outputfile}
+
+# run test
+onecc pack -i ${inputfile} -o ${outputfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-codegen
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+inputfile="sample.circle"
+outputfile="sample.dummy"
+
+# prepare dummy file
+touch ${inputfile}
+
+if [[ ! -f "${inputfile}" ]]; then
+ echo "${filename_ext} ERROR: Missing inputfile"
+ trap_err_onexit
+fi
+
+rm -rf ${outputfile}
+
+# copy dummy-compile to bin folder
+cp dummy-compile ../bin/dummy-compile
+
+# run test
+onecc codegen -b dummy -o ${outputfile} ${inputfile} > /dev/null 2>&1
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=False
+one-import-tflite=False
+one-import-bcq=False
+one-import-onnx=False
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+one-profile=True
+
+[one-profile]
+backend=dummy
+command=test_onnx_model.bin
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# one-profile
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ rm -rf ../bin/dummy-profile
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_021.cfg"
+
+# copy dummy-profile to bin folder
+cp dummy-profile ../bin/dummy-profile
+
+# run test
+onecc -C ${configfile} > ${filename}.log
+
+rm -rf ../bin/dummy-profile
+
+if grep -q "dummy-profile dummy output!!!" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+fi
+
+trap_err_onexit
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage with missing configuration file
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "Not found given configuration file" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_neg_001.cfg"
+
+# run test
+onecc -C ${configfile} > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=True
+one-quantize=False
+one-pack=True
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v2
+
+[one-optimize]
+input_path=inception_v3.circle
+output_path=inception_v3.opt.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage with missing one-pack section in configuration file
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "configuration file must have 'one-pack' section" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_neg_002.cfg"
+
+# run test
+onecc -C ${configfile} > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
--- /dev/null
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v2
+
+[one-optimize]
+input_path=inception_v3.circle
+output_path=inception_v3.opt.circle
+
+[one-pack]
+input_path=inception_v3.opt.circle
+output_path=inception_v3_pkg
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage with missing onecc section in configuration file
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "\[onecc\] section is required in configuration file" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_neg_003.cfg"
+
+# run test
+onecc -C ${configfile} > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=True
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v2
+
+[one-optimize]
+input_path=inception_v3.circle
+output_path=inception_v3.opt.circle
+
+[one-optimize]
+input_path=inception_v4.circle
+output_path=inception_v4.opt.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative usage with duplicate section
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "section 'one-optimize' already exists" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_neg_004.cfg"
+
+# run test
+onecc -C ${configfile} > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
--- /dev/null
+[onecc]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+input_path=inception_v3.pb
+output_path=inception_v3.alt.circle
+input_arrays=input
+input_shapes=1,299,299,3
+output_arrays=InceptionV3/Predictions/Reshape_1
+converter_version=v2
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative one-import-tf intermediate file should not exist
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_neg_005.cfg"
+outputfile="inception_v3.alt.circle"
+intermfile="inception_v3.alt.tflite"
+
+rm -rf ${outputfile}
+rm -rf ${intermfile}
+
+# run test
+onecc -C ${configfile} > ${filename}.log 2>&1
+
+# output should exist
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+# intermediate file should not exist
+if [[ -f "${intermfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+[onecc]
+one-import-tf=False
+one-import-tflite=False
+one-import-bcq=False
+one-import-onnx=True
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-onnx]
+input_path=test_onnx_model.onnx
+output_path=test_onnx_model.circle
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative one-import-onnx intermediate file should not exist
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="onecc_neg_006.cfg"
+outputfile="test_onnx_model.circle"
+intermfile="test_onnx_model.tflite"
+
+rm -rf ${outputfile}
+rm -rf ${intermfile}
+
+# run test
+onecc -C ${configfile} > ${filename}.log 2>&1
+
+# output should exist
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+# intermediate file should not exist
+if [[ -f "${intermfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative subcommand
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "unrecognized arguments" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+# run test
+onecc wronginput > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
--- /dev/null
+#!/bin/bash
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# negative subcommand with empty argument
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ if grep -q "error" "${filename}.log"; then
+ echo "${filename_ext} SUCCESS"
+ exit 0
+ fi
+
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+# run test
+onecc > ${filename}.log 2>&1
+
+echo "${filename_ext} FAILED"
+exit 255
('convert_nchw_to_nhwc',
'Experimental: This will convert NCHW operators to NHWC under the assumption that input model is NCHW.'
),
- ('nchw_to_nhwc_preserve_input_shape',
- 'preserve the input shape of the model (argument for convert_nchw_to_nhwc)'),
- ('nchw_to_nhwc_preserve_output_shape',
- 'preserve the output shape of the model (argument for convert_nchw_to_nhwc)'),
+ ('nchw_to_nhwc_input_shape',
+ 'convert the input shape of the model (argument for convert_nchw_to_nhwc)'),
+ ('nchw_to_nhwc_output_shape',
+ 'convert the output shape of the model (argument for convert_nchw_to_nhwc)'),
('fold_add_v2', 'fold AddV2 op with constant inputs'),
('fold_cast', 'fold Cast op with constant input'),
('fold_dequantize', 'fold Dequantize op'),
('fuse_bcq', 'apply Binary Coded Quantization'),
('fuse_preactivation_batchnorm',
'fuse BatchNorm operators of pre-activations to Convolution op'),
+ ('fuse_mean_with_mean', 'fuse two consecutive Mean ops'),
+ ('fuse_transpose_with_mean',
+ 'fuse Mean with a preceding Transpose under certain conditions'),
('make_batchnorm_gamma_positive',
'make negative gamma of BatchNorm to a small positive value (1e-10).'
' Note that this pass can change the execution result of the model.'
('fuse_instnorm', 'fuse ops to InstanceNorm operator'),
('replace_cw_mul_add_with_depthwise_conv',
'replace channel-wise Mul/Add with DepthwiseConv2D'),
+ ('remove_fakequant', 'remove FakeQuant ops'),
+ ('remove_quantdequant', 'remove Quantize-Dequantize sequence'),
('remove_redundant_reshape', 'fuse or remove subsequent Reshape ops'),
('remove_redundant_transpose', 'fuse or remove subsequent Transpose ops'),
('remove_unnecessary_reshape', 'remove unnecessary reshape ops'),
('resolve_customop_batchmatmul',
'convert Custom(BatchMatmul) op to BatchMatmul op'),
('resolve_customop_matmul', 'convert Custom(Matmul) op to Matmul op'),
+ ('resolve_customop_max_pool_with_argmax',
+ 'convert Custom(MaxPoolWithArgmax) to net of builtin operators'),
('shuffle_weight_to_16x1float32',
'convert weight format of FullyConnected op to SHUFFLED16x1FLOAT32.'
' Note that it only converts weights whose row is a multiple of 16'),
('substitute_pack_to_reshape', 'convert single input Pack op to Reshape op'),
('substitute_squeeze_to_reshape', 'convert certain condition Squeeze to Reshape'),
+ ('substitute_strided_slice_to_reshape',
+ 'convert certain condition StridedSlice to Reshape'),
('substitute_transpose_to_reshape',
'convert certain condition Transpose to Reshape'),
- ('transform_min_max_to_relu6', 'transform Minimum-Maximum pattern to Relu6 op'))
+ ('transform_min_max_to_relu6', 'transform Minimum-Maximum pattern to Relu6 op'),
+ ('transform_min_relu_to_relu6', 'transform Minimum(6)-Relu pattern to Relu6 op'))
_CONSTANT = _CONSTANT()
action='store_true',
help='show program\'s version number and exit')
+ # verbose
+ parser.add_argument(
+ '-V',
+ '--verbose',
+ action='store_true',
+ help='output additional information to stdout or stderr')
+
# configuration file
parser.add_argument('-C', '--config', type=str, help='run with configuation file')
# section name that you want to run in configuration file
the option is processed prior to the configuration file."""
if _is_valid_attr(args, 'config'):
config = configparser.ConfigParser()
+ config.optionxform = str
config.read(args.config)
# if section is given, verify given section
if _is_valid_attr(args, 'section'):
def _make_tf2tfliteV2_cmd(args, driver_path, input_path, output_path):
"""make a command for running tf2tfliteV2.py"""
cmd = [sys.executable, os.path.expanduser(driver_path)]
+ # verbose
+ if _is_valid_attr(args, 'verbose'):
+ cmd.append('--verbose')
# model_format
if _is_valid_attr(args, 'model_format_cmd'):
cmd.append(getattr(args, 'model_format_cmd'))
# profiling
if _is_valid_attr(args, 'generate_profile_data'):
cmd.append('--generate_profile_data')
- # optimization pass
+ # optimization pass (only true/false options)
+ # TODO support options whose number of arguments is more than zero
for opt in _CONSTANT.OPTIMIZATION_OPTS:
if _is_valid_attr(args, opt[0]):
- cmd.append('--' + opt[0])
+ # ./driver --opt[0]
+ if type(getattr(args, opt[0])) is bool:
+ cmd.append('--' + opt[0])
+ """
+ This condition check is for the config file interface; options are usually written as
+ SomeOption=True
+ but during development a user may write
+ SomeOption=False
+ instead of removing the SomeOption entry
+ """
+ if type(getattr(args, opt[0])) is str and not getattr(
+ args, opt[0]).lower() in ['false', '0', 'n']:
+ cmd.append('--' + opt[0])
return cmd
# run one-version
subprocess.call([os.path.join(dir_path, 'one-version'), script_name])
sys.exit()
+
+
+def _safemain(main, mainpath):
+ """execute given method and print with program name for all uncaught exceptions"""
+ try:
+ main()
+ except Exception as e:
+ prog_name = os.path.basename(mainpath)
+ print(f"{prog_name}: {type(e).__name__}: " + str(e), file=sys.stderr)
+ sys.exit(255)
+
+
+def _run(cmd, err_prefix=None, logfile=None):
+ """Execute command in subprocess
+
+ Args:
+ cmd: command to be executed in subprocess
+ err_prefix: prefix to be put before every stderr line
+ logfile: file stream to which both stdout and stderr lines will be written
+ """
+ if logfile is None:
+ with subprocess.Popen(cmd, stderr=subprocess.PIPE, bufsize=1) as p:
+ for line in p.stderr:
+ if err_prefix:
+ line = f"{err_prefix}: ".encode() + line
+ sys.stderr.buffer.write(line)
+ sys.stderr.buffer.flush()
+ else:
+ with subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1) as p:
+ import select
+ inputs = set([p.stdout, p.stderr])
+ while inputs:
+ readable, _, _ = select.select(inputs, [], [])
+ for x in readable:
+ line = x.readline()
+ if len(line) == 0:
+ inputs.discard(x)
+ continue
+ if x == p.stdout:
+ out = sys.stdout
+ if x == p.stderr:
+ out = sys.stderr
+ if err_prefix:
+ line = f"{err_prefix}: ".encode() + line
+ out.buffer.write(line)
+ out.buffer.flush()
+ logfile.write(line)
+ if p.returncode != 0:
+ sys.exit(p.returncode)
auto stream = std::make_unique<T>(path.c_str(), mode);
if (!stream->is_open())
{
- throw std::runtime_error{"ERROR: Failed to open " + path};
+ throw std::runtime_error{"Failed to open " + path};
}
return stream;
}
--- /dev/null
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(pepper_csv2vec STATIC ${SOURCES})
+set_target_properties(pepper_csv2vec PROPERTIES POSITION_INDEPENDENT_CODE ON)
+target_include_directories(pepper_csv2vec PUBLIC include)
+target_link_libraries(pepper_csv2vec PRIVATE nncc_common)
+target_link_libraries(pepper_csv2vec PUBLIC nncc_coverage)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+# Google Test is mandatory for test
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(pepper_csv2vec_test ${TESTS})
+target_link_libraries(pepper_csv2vec_test pepper_csv2vec)
--- /dev/null
+# pepper-csv2vec
+
+Returns a `std::vector<T>` parsed from a CSV-format string input.
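+
+A minimal usage sketch, based only on the declarations added in this change (the `main` wrapper is illustrative):
+
+```cpp
+#include "pepper/csv2vec.h"
+
+#include <cassert>
+
+int main()
+{
+ // parse CSV strings into typed vectors
+ auto names = pepper::csv_to_vector<std::string>("conv,relu,pool");
+ auto sizes = pepper::csv_to_vector<int32_t>("1,299,299,3");
+
+ assert(names.size() == 3);
+ assert(sizes.size() == 4);
+
+ // membership check over the parsed values
+ assert(pepper::is_one_of<std::string>("relu", names));
+ return 0;
+}
+```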
* limitations under the License.
*/
-#ifndef __CIRCLE_HELPER_STRINGS_H__
-#define __CIRCLE_HELPER_STRINGS_H__
+#ifndef __PEPPER_CSV2VEC_H__
+#define __PEPPER_CSV2VEC_H__
#include <string>
#include <vector>
-namespace partee
+namespace pepper
{
template <typename T> std::vector<T> csv_to_vector(const std::string &str);
-bool is_one_of(const std::string &item, const std::vector<std::string> &items);
+template <typename T> bool is_one_of(const T &item, const std::vector<T> &items);
-} // namespace partee
+} // namespace pepper
-#endif // __CIRCLE_HELPER_STRINGS_H__
+#endif // __PEPPER_CSV2VEC_H__
* limitations under the License.
*/
-#include "HelperStrings.h"
+#include "pepper/csv2vec.h"
#include <algorithm>
#include <sstream>
+#include <cassert>
-namespace partee
+namespace pepper
{
template <> std::vector<std::string> csv_to_vector(const std::string &str)
return ret;
}
-bool is_one_of(const std::string &item, const std::vector<std::string> &items)
+// TODO merge std::string and int32_t type
+
+template <> std::vector<int32_t> csv_to_vector(const std::string &str)
+{
+ std::vector<int32_t> ret;
+ std::istringstream is(str);
+ for (int32_t i; is >> i;)
+ {
+ assert(i != ',');
+ ret.push_back(i);
+ if (is.peek() == ',')
+ is.ignore();
+ }
+ return ret;
+}
+
+template <> bool is_one_of(const std::string &item, const std::vector<std::string> &items)
{
return std::find(items.begin(), items.end(), item) != items.end();
}
-} // namespace partee
+} // namespace pepper
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pepper/csv2vec.h"
+
+#include <gtest/gtest.h>
+
+TEST(csv2vec, simple_string)
+{
+ auto ret = pepper::csv_to_vector<std::string>("hello,world");
+
+ ASSERT_EQ(2, ret.size());
+ ASSERT_TRUE("hello" == ret.at(0));
+ ASSERT_TRUE("world" == ret.at(1));
+}
+
+TEST(csv2vec, simple_int32)
+{
+ auto ret = pepper::csv_to_vector<int32_t>("1,2,3");
+
+ ASSERT_EQ(3, ret.size());
+ ASSERT_EQ(1, ret.at(0));
+ ASSERT_EQ(3, ret.at(2));
+}
+
+TEST(csv2vec, is_one_of)
+{
+ auto ret = pepper::csv_to_vector<std::string>("hello,world");
+
+ ASSERT_TRUE(pepper::is_one_of<std::string>("hello", ret));
+ ASSERT_FALSE(pepper::is_one_of<std::string>("good", ret));
+}
+
+TEST(csv2vec, empty_string_NEG)
+{
+ // should not abort
+ EXPECT_NO_THROW(pepper::csv_to_vector<std::string>(""));
+}
+
+TEST(csv2vec, invalid_int32_NEG)
+{
+ auto ret = pepper::csv_to_vector<int32_t>("hello,world");
+
+ ASSERT_EQ(0, ret.size());
+}
{
- "scale": 0.00014586378529202193,
+ "scale": 0.00014983004075475037,
"zero_point": 0.0
}
{
- "scale": 0.00014956798986531794,
+ "scale": 0.00014983004075475037,
"zero_point": 0.0
}
{
- "scale": 0.035256847739219666,
- "zero_point": 123.0
+ "scale": 0.038689617067575455,
+ "zero_point": 128.0
}
{
- "scale": 0.0385618582367897,
- "zero_point": 129.0
+ "scale": 0.038689617067575455,
+ "zero_point": 128.0
}
#!/usr/bin/env bash
-''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
.help("Show version information and exit")
.exit_with(print_version);
+ arser.add_argument("-V", "--verbose")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("output additional information to stdout or stderr");
+
arser.add_argument("--input_model")
.nargs(1)
.type(arser::DataType::STR)
.type(arser::DataType::STR)
.help("Record mode. percentile (default) or moving_average");
+ arser.add_argument("--input_data_format")
+ .nargs(1)
+ .type(arser::DataType::STR)
+ .help("Input data format. h5/hdf5 (default) or list/filelist");
+
arser.add_argument("--generate_profile_data")
.nargs(0)
.required(false)
return 255;
}
+ if (arser.get<bool>("--verbose"))
+ {
+ // The third parameter of setenv means REPLACE.
+ // If REPLACE is zero, it does not overwrite an existing value.
+ setenv("LUCI_LOG", "100", 0);
+ }
+
auto settings = luci::UserSettings::settings();
auto input_model_path = arser.get<std::string>("--input_model");
std::string mode("percentile");
float min_percentile = 1.0;
float max_percentile = 99.0;
+ std::string input_data_format("h5");
if (arser["--min_percentile"])
min_percentile = arser.get<float>("--min_percentile");
if (arser["--generate_profile_data"])
settings->set(luci::UserSettings::Key::ProfilingDataGen, true);
+ if (arser["--input_data_format"])
+ input_data_format = arser.get<std::string>("--input_data_format");
+
RecordMinMax rmm;
// Initialize interpreter and observer
{
auto input_data_path = arser.get<std::string>("--input_data");
- // Profile min/max while executing the given input data
- rmm.profileData(mode, input_data_path, min_percentile, max_percentile);
+ if (input_data_format == "h5" || input_data_format == "hdf5")
+ {
+ // Profile min/max while executing the H5 data
+ rmm.profileData(mode, input_data_path, min_percentile, max_percentile);
+ }
+ // input_data is a text file containing one file path per line.
+ // Each data file is composed of the inputs of a model, concatenated in
+ // the same order as the model's input indices
+ //
+ // For example, for a model with n inputs, the contents of each data
+ // file can be visualized as below
+ // [input 1][input 2]...[input n]
+ // |start............end of file|
+ else if (input_data_format == "list" || input_data_format == "filelist")
+ {
+ // Profile min/max while executing the list of Raw data
+ rmm.profileRawData(mode, input_data_path, min_percentile, max_percentile);
+ }
+ else if (input_data_format == "directory" || input_data_format == "dir")
+ {
+ // Profile min/max while executing all files under the given directory
+ // The contents of each file is same as the raw data in the 'list' type
+ rmm.profileRawDataDirectory(mode, input_data_path, min_percentile, max_percentile);
+ }
+ else
+ {
+ throw std::runtime_error(
+ "Unsupported input data format (supported formats: h5/hdf5 (default), list/filelist)");
+ }
}
else
{
if (percentile < 0 || percentile > 100)
throw std::runtime_error("Percentile must be ranged from 0 to 100");
- if (percentile == 0.0)
- return vector.front();
-
- if (percentile == 100.0)
- return vector.back();
-
if (vector.empty())
throw std::runtime_error("Percentile must take a non-empty vector as an argument");
copy.assign(vector.begin(), vector.end());
std::sort(copy.begin(), copy.end());
+ if (percentile == 0.0)
+ return copy.front();
+
+ if (percentile == 100.0)
+ return copy.back();
+
int index = static_cast<int>(std::floor((copy.size() - 1) * percentile / 100.0));
float percent_i = static_cast<float>(index) / static_cast<float>(copy.size() - 1);
void profileData(const std::string &mode, const std::string &input_data_path,
float min_percentile, float max_percentile);
+ void profileRawData(const std::string &mode, const std::string &input_data_path,
+ float min_percentile, float max_percentile);
+
+ void profileRawDataDirectory(const std::string &mode, const std::string &input_data_path,
+ float min_percentile, float max_percentile);
+
void profileDataWithRandomInputs(const std::string &mode, float min_percentile,
float max_percentile);
return;
}
- if (node->opcode() == luci::CircleOpcode::ARG_MAX)
- {
- // Output of arg_max is the index of the largest value across axes of a tensor
- // this should not be quantized
- return;
- }
-
if (node->dtype() == DataType::BOOL)
{
// Bool type tensor is not quantized
// Only support recording of float32 values
if (tensor->element_type() != DataType::FLOAT32)
- throw std::runtime_error("Tensor's data type is not float");
+ {
+ // Exceptions that should be processed in backends
+ switch (node->opcode())
+ {
+ case luci::CircleOpcode::ARG_MAX:
+ // Output of arg_max is the index of the largest value across axes of a tensor.
+ // It always has integer type.
+ case luci::CircleOpcode::CAST:
+ // Cast is quantized only if it converts <type> -> float.
+ // Other cases should be processed in backends.
+ case luci::CircleOpcode::RESHAPE:
+ // Reshape changes only the shape of the input tensor; effectively it is a no-op.
+ return;
+ default:
+ throw std::runtime_error("Tensor's data type is not float");
+ }
+ }
const auto data = tensor->data<float>();
const auto num_elements = tensor->shape().num_elements();
if (isnan(number))
continue;
+ // TODO use metadata hints to detect such cases
+ if (number == std::numeric_limits<float>::lowest())
+ continue;
+
all_nan = false;
if (number > max)
#include <luci/CircleFileExpContract.h>
#include <luci/IR/CircleQuantParam.h>
+#include <dirent.h>
#include <algorithm>
#include <cmath>
#include <fstream>
namespace
{
+void readDataFromFile(const std::string &filename, std::vector<char> &data, size_t data_size)
+{
+ assert(data.size() == data_size); // FIX_CALLER_UNLESS
+
+ std::ifstream fs(filename, std::ifstream::binary);
+ if (fs.fail())
+ throw std::runtime_error("Cannot open file \"" + filename + "\".\n");
+ if (fs.read(data.data(), data_size).fail())
+ throw std::runtime_error("Failed to read data from file \"" + filename + "\".\n");
+}
+
std::vector<uint8_t> genRandomBoolData(std::mt19937 &gen, uint32_t num_elements)
{
std::uniform_int_distribution<> dist(0, 1);
model_data.size()};
if (!circle::VerifyModelBuffer(verifier))
{
- throw std::runtime_error("ERROR: Failed to verify circle '" + input_model_path + "'");
+ throw std::runtime_error("Failed to verify circle '" + input_model_path + "'");
}
_module = luci::Importer().importModule(circle::GetModel(model_data.data()));
if (_module == nullptr)
{
- throw std::runtime_error("ERROR: Failed to load '" + input_model_path + "'");
+ throw std::runtime_error("Failed to load '" + input_model_path + "'");
}
// Initialize interpreter
_interpreter->attachObserver(_observer.get());
}
+// input_data_path is a path to a directory
+// The directory should contain binary files, each holding raw data that is
+// ready to be consumed by the input circle model without any modification
+// TODO reduce duplicated code with profileRawData
+void RecordMinMax::profileRawDataDirectory(const std::string &mode,
+ const std::string &input_data_path, float min_percentile,
+ float max_percentile)
+{
+ struct dirent *entry = nullptr;
+ DIR *dp = nullptr;
+
+ dp = opendir(input_data_path.c_str());
+ if (not dp)
+ throw std::runtime_error("Cannot open directory. Please check \"" + input_data_path +
+ "\" is a directory.\n");
+
+ uint32_t num_records = 0;
+ const auto input_nodes = loco::input_nodes(_module->graph());
+
+ // Get total input size
+ uint32_t total_input_size = 0;
+ for (auto input : input_nodes)
+ {
+ const auto *input_node = loco::must_cast<const luci::CircleInput *>(input);
+ total_input_size += getTensorSize(input_node);
+ }
+
+ while ((entry = readdir(dp)) != nullptr)
+ {
+ // Skip if the entry is not a regular file
+ if (entry->d_type != DT_REG)
+ continue;
+
+ const std::string filename = entry->d_name;
+ std::cout << "Recording " << num_records << "'th data" << std::endl;
+
+ // Read data from file to buffer
+ // Assumption: For a multi-input model, the binary file should have inputs concatenated in the
+ // same order as the input indices.
+ std::vector<char> input_data(total_input_size);
+ readDataFromFile(input_data_path + "/" + filename, input_data, total_input_size);
+
+ // Write data from buffer to interpreter
+ uint32_t offset = 0;
+ for (auto input : input_nodes)
+ {
+ const auto *input_node = loco::must_cast<const luci::CircleInput *>(input);
+ const auto input_size = getTensorSize(input_node);
+ _interpreter->writeInputTensor(input_node, input_data.data() + offset, input_size);
+
+ offset += input_size;
+ }
+
+ _interpreter->interpret();
+
+ num_records++;
+ }
+
+ closedir(dp);
+
+ if (num_records == 0)
+ throw std::runtime_error("The input data file does not contain any record.");
+
+ std::cout << "Recording finished. Number of recorded data: " << num_records << std::endl;
+
+ update_quantparam(_observer.get(), mode, min_percentile, max_percentile);
+}
+
+// input_data_path is a text file which specifies the representative data
+// The text file should contain one absolute file path per line.
+// Each referenced file should be a binary file containing one representative datum,
+// ready to be consumed by the input circle model without any modification
+// NOTE If a model has multiple inputs, the binary file should have the inputs concatenated in the
+// same order as the input indices of the circle model.
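+// For illustration only, such a list file could look like (hypothetical paths):
+//   /path/to/record_0.bin
+//   /path/to/record_1.bin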
+void RecordMinMax::profileRawData(const std::string &mode, const std::string &input_data_path,
+ float min_percentile, float max_percentile)
+{
+ std::ifstream input_file(input_data_path);
+ if (input_file.fail())
+ throw std::runtime_error("Cannot open file \"" + input_data_path + "\".\n");
+
+ std::string record;
+ uint32_t num_records = 0;
+ const auto input_nodes = loco::input_nodes(_module->graph());
+
+ // Get total input size
+ uint32_t total_input_size = 0;
+ for (auto input : input_nodes)
+ {
+ const auto *input_node = loco::must_cast<const luci::CircleInput *>(input);
+ total_input_size += getTensorSize(input_node);
+ }
+
+ while (getline(input_file, record))
+ {
+ std::cout << "Recording " << num_records << "'th data" << std::endl;
+
+ // Read data from file to buffer
+ // Assumption: For a multi-input model, the binary file should have inputs concatenated in the
+ // same order as the input indices.
+ std::vector<char> input_data(total_input_size);
+ readDataFromFile(record, input_data, total_input_size);
+
+ // Write data from buffer to interpreter
+ uint32_t offset = 0;
+ for (auto input : input_nodes)
+ {
+ const auto *input_node = loco::must_cast<const luci::CircleInput *>(input);
+ const auto input_size = getTensorSize(input_node);
+ _interpreter->writeInputTensor(input_node, input_data.data() + offset, input_size);
+
+ offset += input_size;
+ }
+
+ _interpreter->interpret();
+
+ num_records++;
+ }
+
+ if (num_records == 0)
+ throw std::runtime_error("The input data file does not contain any record.");
+
+ std::cout << "Recording finished. Number of recorded data: " << num_records << std::endl;
+
+ update_quantparam(_observer.get(), mode, min_percentile, max_percentile);
+}
+
void RecordMinMax::profileData(const std::string &mode, const std::string &input_data_path,
float min_percentile, float max_percentile)
{
if (num_inputs != importer.numInputs(record_idx))
throw std::runtime_error("Wrong number of inputs.");
- if (record_idx % 100 == 0)
- std::cout << "Recording " << record_idx << "'th data" << std::endl;
+ std::cout << "Recording " << record_idx << "'th data" << std::endl;
for (int32_t input_idx = 0; input_idx < num_inputs; input_idx++)
{
if (!exporter.invoke(&contract))
{
- throw std::runtime_error("ERROR: Failed to export '" + output_model_path + "'");
+ throw std::runtime_error("Failed to export '" + output_model_path + "'");
}
}
std::vector<T> _values;
};
+template <> class ExplicitDataChef<std::string> final : public DataChef
+{
+public:
+ ExplicitDataChef()
+ {
+ // DO NOTHING
+ }
+
+public:
+ std::vector<uint8_t> generate(int32_t count) const override;
+
+public:
+ void insert(const std::string &value) { _values.emplace_back(value); }
+
+private:
+ void write_value(std::vector<uint8_t> &res, int32_t value) const;
+
+private:
+ std::vector<std::string> _values;
+};
+
template <typename T> struct ExplicitDataChefFactory : public DataChefFactory
{
std::unique_ptr<DataChef> create(const Arguments &args) const
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "souschef/Data/Explicit.h"
+
+#include <string>
+#include <vector>
+
+namespace souschef
+{
+
+/**
+ * @note This emulates the TensorFlow method: int DynamicBuffer::WriteToBuffer(char** buffer)
+ * Memory structure:
+ * int32_t count
+ * int32_t offsets[count + 1]
+ * string values[count]
+ * where each string is stored like a std::string's characters, without a terminating null byte
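+ *
+ * For illustration (assumed values, not part of this change): generate(2)
+ * with _values == {"ab", "c"} yields 19 bytes:
+ * count = 2
+ * offsets = {16, 18, 19} (header size is (count + 2) * sizeof(int32_t) = 16)
+ * values = "abc" (raw characters, no null terminators)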
+ */
+std::vector<uint8_t> ExplicitDataChef<std::string>::generate(int32_t count) const
+{
+ std::vector<uint8_t> res;
+
+ // write count
+ write_value(res, count);
+
+ // write first item offset
+ int32_t start = sizeof(int32_t) * (count + 2);
+ write_value(res, start);
+
+ // write succeeding items offset (or the end)
+ int32_t offset = start;
+ for (uint32_t n = 0; n < count; ++n)
+ {
+ std::string const value = (n < _values.size()) ? _values.at(n) : std::string{};
+ offset += value.length();
+ write_value(res, offset);
+ }
+
+ for (uint32_t n = 0; n < count; ++n)
+ {
+ std::string const value = (n < _values.size()) ? _values.at(n) : std::string{};
+ const uint8_t *arr = reinterpret_cast<const uint8_t *>(value.c_str());
+
+ for (uint32_t b = 0; b < value.length(); ++b)
+ {
+ res.emplace_back(arr[b]);
+ }
+ }
+
+ return res;
+}
+
+void ExplicitDataChef<std::string>::write_value(std::vector<uint8_t> &res, int32_t value) const
+{
+ const uint8_t *arr = reinterpret_cast<const uint8_t *>(&value);
+
+ for (uint32_t b = 0; b < sizeof(int32_t); ++b)
+ {
+ res.emplace_back(arr[b]);
+ }
+}
+
+} // namespace souschef
}
template <> bool to_number(const std::string &s)
{
- if (std::stoi(s) || s == "T" || s == "t" || s == "TRUE" || s == "true")
+ if (s == "T" || s == "t" || s == "TRUE" || s == "true" || s == "1")
return true;
- return false;
+ if (s == "F" || s == "f" || s == "FALSE" || s == "false" || s == "0")
+ return false;
+ throw std::invalid_argument("Unsupported boolean argument");
}
+template <> std::string to_number(const std::string &s) { return s; }
} // namespace souschef
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import tensorflow as tf
import argparse
import sys
parser = argparse.ArgumentParser(
description=("Command line tool to run TensorFlow Lite Converter."))
+ # Verbose
+ parser.add_argument(
+ "-V",
+ "--verbose",
+ action="store_true",
+ help="output additional information to stdout or stderr")
+
# Converter version.
converter_version = parser.add_mutually_exclusive_group(required=True)
converter_version.add_argument(
open(flags.output_path, "wb").write(tflite_model)
+def _apply_verbosity(verbosity):
+ # NOTE
+ # TF_CPP_MIN_LOG_LEVEL
+ # 0 : INFO + WARNING + ERROR + FATAL
+ # 1 : WARNING + ERROR + FATAL
+ # 2 : ERROR + FATAL
+ # 3 : FATAL
+ #
+ # TODO Find a better way to suppress the traceback on error
+ # tracebacklimit
+ # The default is 1000.
+ # When set to 0 or less, all traceback information is suppressed
+ if verbosity:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
+ sys.tracebacklimit = 1000
+ else:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+ sys.tracebacklimit = 0
+
+
def _convert(flags):
+ _apply_verbosity(flags.verbose)
+
if (flags.v1):
_v1_convert(flags)
else:
if __name__ == "__main__":
- main()
+ try:
+ main()
+ except Exception as e:
+ prog_name = os.path.basename(__file__)
+ print(f"{prog_name}: {type(e).__name__}: " + str(e), file=sys.stderr)
+ sys.exit(255)
auto stream = std::make_unique<T>(path.c_str(), mode);
if (!stream->is_open())
{
- throw std::runtime_error{"ERROR: Failed to open " + path};
+ throw std::runtime_error{"Failed to open " + path};
}
return stream;
}
return tflite::TensorType_UINT8;
case tflchef::INT64:
return tflite::TensorType_INT64;
+ case tflchef::STRING:
+ return tflite::TensorType_STRING;
case tflchef::BOOL:
return tflite::TensorType_BOOL;
case tflchef::INT16:
* limitations under the License.
*/
-#include "MaxPoolWithArgMax.h"
+#include "MaxPoolWithArgmax.h"
#include "flatbuffers/flexbuffers.h"
-flatbuffers::Offset<void> MaxPoolWithArgMaxChef::value(flatbuffers::FlatBufferBuilder &fbb) const
+flatbuffers::Offset<void> MaxPoolWithArgmaxChef::value(flatbuffers::FlatBufferBuilder &fbb) const
{
return flatbuffers::Offset<void>();
}
flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
-MaxPoolWithArgMaxChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
+MaxPoolWithArgmaxChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);
- assert(operation.type() == "MaxPoolWithArgMax");
+ assert(operation.type() == "MaxPoolWithArgmax");
/**
* REGISTER_OP("MaxPoolWithArgmax")
}
std::unique_ptr<OpChef>
-MaxPoolWithArgMaxChefFactory::create(const tflchef::Operation *operation) const
+MaxPoolWithArgmaxChefFactory::create(const tflchef::Operation *operation) const
{
- return std::unique_ptr<OpChef>{new MaxPoolWithArgMaxChef{operation}};
+ return std::unique_ptr<OpChef>{new MaxPoolWithArgmaxChef{operation}};
}
#include "OpChef.h"
-class MaxPoolWithArgMaxChef final : public OpChef
+class MaxPoolWithArgmaxChef final : public OpChef
{
public:
- explicit MaxPoolWithArgMaxChef(const tflchef::Operation *operation) : _operation{operation}
+ explicit MaxPoolWithArgmaxChef(const tflchef::Operation *operation) : _operation{operation}
{
// DO NOTHING
}
const tflchef::Operation *_operation;
};
-struct MaxPoolWithArgMaxChefFactory final : public OpChefFactory
+struct MaxPoolWithArgmaxChefFactory final : public OpChefFactory
{
std::unique_ptr<OpChef> create(const tflchef::Operation *operation) const override;
};
--- /dev/null
+#ifndef DATA_CHEF
+#error "Define DATA_CHEF first"
+#endif // DATA_CHEF
+
+// DATA_CHEF(TYPE, NAME, FACTORY_CLASS)
+// "TYPE" SHOULD BE an enum tag of tflchef::TensorType
+DATA_CHEF(FLOAT32, constant, ConstantDataChefFactory<float>)
+DATA_CHEF(BOOL, constant, ConstantDataChefFactory<bool>)
+DATA_CHEF(UINT8, constant, ConstantDataChefFactory<uint8_t>)
+DATA_CHEF(INT16, constant, ConstantDataChefFactory<int16_t>)
+DATA_CHEF(INT32, constant, ConstantDataChefFactory<int32_t>)
+DATA_CHEF(INT64, constant, ConstantDataChefFactory<int64_t>)
+DATA_CHEF(INT64, explicit, ExplicitDataChefFactory<int64_t>)
+DATA_CHEF(INT32, explicit, ExplicitDataChefFactory<int32_t>)
+DATA_CHEF(INT16, explicit, ExplicitDataChefFactory<int16_t>)
+DATA_CHEF(UINT8, explicit, ExplicitDataChefFactory<uint8_t>)
+DATA_CHEF(BOOL, explicit, ExplicitDataChefFactory<bool>)
+DATA_CHEF(FLOAT32, explicit, ExplicitDataChefFactory<float>)
+DATA_CHEF(STRING, explicit, ExplicitDataChefFactory<std::string>)
+DATA_CHEF(FLOAT32, gaussian, GaussianFloat32DataChefFactory)
+DATA_CHEF(INT32, gaussian, GaussianInt32DataChefFactory)
+DATA_CHEF(INT16, gaussian, GaussianInt16DataChefFactory)
+DATA_CHEF(UINT8, gaussian, GaussianUint8DataChefFactory)
static DataChefRegistry s64;
static DataChefRegistry fp32;
static DataChefRegistry u8;
+ static DataChefRegistry string;
static DataChefRegistry boolean;
static DataChefRegistry s16;
return fp32;
case tflchef::UINT8:
return u8;
+ case tflchef::STRING:
+ return string;
case tflchef::BOOL:
return boolean;
case tflchef::INT16:
#define DATA_CHEF(TYPE, NAME, FACTORY_CLASS) \
data_chef_registry(::tflchef::TYPE) \
.add(#NAME, std::unique_ptr<FACTORY_CLASS>(new FACTORY_CLASS()));
-#include <souschef/DataChef.def>
+#include "DataChef.def"
#undef DATA_CHEF
//
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Quantize.h"
+
+flatbuffers::Offset<void> QuantizeChef::value(flatbuffers::FlatBufferBuilder &fbb) const
+{
+ return flatbuffers::Offset<void>();
+}
+
+std::unique_ptr<OpChef> QuantizeChefFactory::create(const tflchef::Operation *operation) const
+{
+ return std::unique_ptr<OpChef>{new QuantizeChef{operation}};
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_QUANTIZE_H__
+#define __OP_QUANTIZE_H__
+
+#include "OpChef.h"
+
+class QuantizeChef final : public OpChef
+{
+public:
+ explicit QuantizeChef(const tflchef::Operation *operation) : _operation{operation}
+ {
+ // DO NOTHING
+ }
+
+public:
+ tflite::BuiltinOperator code(void) const override { return tflite::BuiltinOperator_QUANTIZE; }
+
+ tflite::BuiltinOptions type(void) const override { return tflite::BuiltinOptions_NONE; }
+
+ flatbuffers::Offset<void> value(flatbuffers::FlatBufferBuilder &fbb) const override;
+
+private:
+ const tflchef::Operation *_operation;
+};
+
+struct QuantizeChefFactory final : public OpChefFactory
+{
+ std::unique_ptr<OpChef> create(const tflchef::Operation *operation) const override;
+};
+
+#endif // __OP_QUANTIZE_H__
OP_CHEF(PadV2, PadV2ChefFactory)
OP_CHEF(Pow, PowChefFactory)
OP_CHEF(PRelu, PReluChefFactory)
+OP_CHEF(Quantize, QuantizeChefFactory)
OP_CHEF(Range, RangeChefFactory)
OP_CHEF(Rank, RankChefFactory)
OP_CHEF(ReduceAny, ReduceAnyChefFactory)
OP_CHEF(BroadcastTo, BroadcastToChefFactory)
OP_CHEF(MatMul, MatMulChefFactory)
OP_CHEF(MatrixBandPart, MatrixBandPartChefFactory)
-OP_CHEF(MaxPoolWithArgMax, MaxPoolWithArgMaxChefFactory)
+OP_CHEF(MaxPoolWithArgmax, MaxPoolWithArgmaxChefFactory)
#include "Op/ReverseV2.h"
#include "Op/Round.h"
#include "Op/Rsqrt.h"
+#include "Op/Quantize.h"
#include "Op/ScatterNd.h"
#include "Op/SegmentSum.h"
#include "Op/Select.h"
#include "CustomOp/BroadcastTo.h"
#include "CustomOp/MatMul.h"
#include "CustomOp/MatrixBandPart.h"
-#include "CustomOp/MaxPoolWithArgMax.h"
+#include "CustomOp/MaxPoolWithArgmax.h"
#endif // __OP_CHEFS_H__
INT32 = 2;
UINT8 = 3;
INT64 = 4;
+ STRING = 5;
BOOL = 6;
INT16 = 7;
}
// NONE
}
-message MaxPoolWithArgMaxOptions {
+message MaxPoolWithArgmaxOptions {
optional Padding padding = 1 [default = VALID];
optional int32 stride_w = 2 [default = 1];
optional int32 stride_h = 3 [default = 1];
optional SegmentSumOptions segment_sum_options = 206;
optional AddNOptions add_n_options = 207;
optional MatMulOptions matmul_options = 208;
- optional MaxPoolWithArgMaxOptions max_pool_with_argmax_options = 209;
+ optional MaxPoolWithArgmaxOptions max_pool_with_argmax_options = 209;
// NOTE if there are more than two options with same type of Options
// use the number not listed in the above reserve list
}
--- /dev/null
+operand {
+ name: "ifm1"
+ type: BOOL
+ shape { dim: 6 }
+}
+operand {
+ name: "ifm2"
+ type: BOOL
+ shape { dim: 6 }
+ filler {
+ tag: "explicit"
+ arg: "T"
+ arg: "f"
+ arg: "0"
+ arg: "1"
+ arg: "true"
+ arg: "FALSE"
+ }
+}
+operand {
+ name: "ofm"
+ type: BOOL
+ shape { dim: 6 }
+}
+operation {
+ type: "LogicalAnd"
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+}
+output: "ofm"
--- /dev/null
+operand {
+ name: "ifm"
+ type: STRING
+ shape { }
+}
+operand {
+ name: "suffix"
+ type: STRING
+ shape { }
+ filler {
+ tag: "explicit"
+ arg: "Hello"
+ }
+}
+operand {
+ name: "ofm"
+ type: STRING
+ shape { }
+}
+operation {
+ type: "Add"
+ input: "ifm"
+ input: "suffix"
+ output: "ofm"
+ add_options {
+ activation: NONE
+ }
+}
+input: "ifm"
+output: "ofm"
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Quantize.h"
+
+#include "Convert.h"
+
+namespace tflchef
+{
+
+void TFliteOpQuantize::filler(const tflite::Operator *op, TFliteImport *import,
+ tflchef::ModelRecipe *model_recipe) const
+{
+ // Nothing to do with filler
+}
+
+tflchef::Operation *TFliteOpQuantize::build(const tflite::Operator *, TFliteImport *import,
+ tflchef::ModelRecipe *model_recipe) const
+{
+ auto operation = model_recipe->add_operation();
+
+ operation->set_type("Quantize");
+
+ return operation;
+}
+
+} // namespace tflchef
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TFLITE_OP_QUANTIZE_H__
+#define __TFLITE_OP_QUANTIZE_H__
+
+#include "TFliteOpChef.h"
+
+namespace tflchef
+{
+
+/**
+ * @brief tflchef operator builder for Quantize
+ */
+class TFliteOpQuantize : public TFliteOpChef
+{
+public:
+ void filler(const tflite::Operator *op, TFliteImport *import,
+ tflchef::ModelRecipe *model_recipe) const override;
+ tflchef::Operation *build(const tflite::Operator *op, TFliteImport *import,
+ tflchef::ModelRecipe *model_recipe) const override;
+};
+
+} // namespace tflchef
+
+#endif // __TFLITE_OP_QUANTIZE_H__
#include "Op/PadV2.h"
#include "Op/Pow.h"
#include "Op/PRelu.h"
+#include "Op/Quantize.h"
#include "Op/Range.h"
#include "Op/Rank.h"
#include "Op/ReduceAny.h"
REG_TFL_OP(PADV2, TFliteOpPadV2);
REG_TFL_OP(POW, TFliteOpPow);
REG_TFL_OP(PRELU, TFliteOpPRelu);
+ REG_TFL_OP(QUANTIZE, TFliteOpQuantize);
REG_TFL_OP(RANGE, TFliteOpRange);
REG_TFL_OP(RANK, TFliteOpRank);
REG_TFL_OP(REDUCE_ANY, TFliteOpReduceAny);
os << std::boolalpha;
os << "align_corners(" << resize_params->align_corners() << ")";
os << "half_pixel_centers(" << resize_params->half_pixel_centers() << ")";
+ os << std::noboolalpha;
os << std::endl;
}
}
os << " ";
os << std::boolalpha;
os << "align_corners(" << resize_params->align_corners() << ")";
+ os << std::noboolalpha;
os << std::endl;
}
}
.help("Show version information and exit")
.exit_with(print_version);
+ arser.add_argument("-V", "--verbose")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("output additional information to stdout or stderr");
+
arser.add_argument("tflite")
.nargs(1)
.type(arser::DataType::STR)
}
catch (const std::runtime_error &err)
{
- std::cout << err.what() << std::endl;
+ std::cerr << err.what() << std::endl;
std::cout << arser;
return 255;
}
model._data.size()};
if (!tflite::VerifyModelBuffer(verifier))
{
- throw std::runtime_error("ERROR: Failed to verify tflite");
+ throw std::runtime_error("Failed to verify tflite");
}
_operator_codes_offset =
if (NOT VCONONE_VERSION)
- set(VCONONE_VERSION 0x00000000000f0001)
+ set(VCONONE_VERSION 0x0000000000110001)
# NOTE order is [build patch minor major]
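+ # e.g. 0x00000000000f0001 encodes 1.15.0: build 0x0000, patch 0x0000, minor 0x000f, major 0x0001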
# if VCONONE_VERSION is set with -D option, it will be cached
# you may have to remove cache file if you remove -D option
struct PoolParams
{
- FusedActivationFunctionType activation;
- PaddingType padding_type;
PaddingValues padding_values;
int stride_height;
int stride_width;
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_FLOOR_DIV_H__
+#define __NNFW_CKER_FLOOR_DIV_H__
+
+#include "cker/Shape.h"
+#include "cker/Utils.h"
+
+namespace nnfw
+{
+namespace cker
+{
+
+template <typename T>
+inline void FloorDivBroadcast(const Shape &unextended_input1_shape, const T *input1_data,
+ const Shape &unextended_input2_shape, const T *input2_data,
+ const Shape &unextended_output_shape, T *output_data)
+{
+ assert(unextended_input1_shape.DimensionsCount() <= 4);
+ assert(unextended_input2_shape.DimensionsCount() <= 4);
+ assert(unextended_output_shape.DimensionsCount() <= 4);
+ const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape);
+
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1,
+ &desc2);
+
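+ // Walk the 4-D extended output shape; SubscriptToIndex maps each output
+ // coordinate back into the (possibly broadcast) input operands.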
+ for (int b = 0; b < output_shape.Dims(0); ++b)
+ {
+ for (int y = 0; y < output_shape.Dims(1); ++y)
+ {
+ for (int x = 0; x < output_shape.Dims(2); ++x)
+ {
+ for (int c = 0; c < output_shape.Dims(3); ++c)
+ {
+ auto out_idx = Offset(output_shape, b, y, x, c);
+ auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+ auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+ auto in1_val = input1_data[in1_idx];
+ auto in2_val = input2_data[in2_idx];
+ output_data[out_idx] = std::floor(
+ std::divides<double>()(static_cast<double>(in1_val), static_cast<double>(in2_val)));
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void FloorDivElementwise(const Shape &shape, const T *input1_data, const T *input2_data,
+ T *output_data)
+{
+
+ int num_elements = shape.FlatSize();
+
+ for (int t = 0; t < num_elements; t++)
+ {
+ output_data[t] = std::floor(std::divides<double>()(static_cast<double>(input1_data[t]),
+ static_cast<double>(input2_data[t])));
+ }
+}
+
+} // namespace cker
+
+} // namespace nnfw
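+
+// Usage sketch (illustrative only; assumes the usual cker Shape(rank, dims) constructor):
+//   const int32_t dims[1] = {4};
+//   nnfw::cker::Shape shape(1, dims);
+//   float lhs[] = {7.f, -7.f, 3.f, 9.f};
+//   float rhs[] = {2.f, 2.f, -2.f, 3.f};
+//   float out[4];
+//   nnfw::cker::FloorDivElementwise(shape, lhs, rhs, out); // out = {3, -4, -2, 3}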
+#endif // __NNFW_CKER_FLOOR_DIV_H__
namespace optimized
{
+// gemmlowp's shared context is not thread-safe, so concurrent GEMM calls are
+// serialized with this mutex (see the lock_guard before GemmWithOutputPipeline).
+std::mutex _gemmlowp_mutex;
+
struct GemmlowpOutputPipeline
{
typedef gemmlowp::VectorMap<const int32_t, gemmlowp::VectorShape::Col> ColVectorMap;
const auto &output_pipeline =
GemmlowpOutputPipeline::MakeExp(bias_data, output_rows, output_offset, output_multiplier,
output_shift, output_activation_min, output_activation_max);
+
+ std::lock_guard<std::mutex> lock_guard(_gemmlowp_mutex);
gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, input_offset,
output_pipeline);
author = 'Samsung Research & contributors'
# The full version, including alpha/beta/rc tags
-release = '1.15.0'
+release = '1.17.0'
# -- General configuration ---------------------------------------------------
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
-===
+1.10
+====
.. toctree::
:maxdepth: 2
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
-===
+1.11
+====
.. toctree::
:maxdepth: 2
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
-===
+1.12
+====
.. toctree::
:maxdepth: 2
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
-===
+1.13
+====
.. toctree::
:maxdepth: 2
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
-===
+1.14
+====
.. toctree::
:maxdepth: 2
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
-===
+1.15
+====
.. toctree::
:maxdepth: 2
--- /dev/null
+.. ONE documentation master file, created by
+ sphinx-quickstart on Thu May 20 12:56:12 2021.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+1.16
+====
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+ ./release-note-1.16.0.md
+ ./release-note-1.16.1.md
--- /dev/null
+# Release Note 1.16.0
+
+## ONE Compiler
+
+### Compiler Frontend
+
+- Enable `PadV2` in luci-interpreter and quantization
+- Provide `circle-tensordump` and `circledump` as development tools
+- Provide `luci-eval-driver` as a test tool
+- Enable `STRING` type as constant values in `CircleConst`
+- Fix `CircleCustom` to allow 0 inputs and 0 outputs
+- Enable debian package generation
+- More optimization passes
+  - Min(6)+ReLU to ReLU6
+  - Remove FakeQuant Op
+- Experimental ONNX support upgraded to version 1.8.0 with an additional patch
+- Fix bugs where one-cmds' config file didn't evaluate booleans properly
--- /dev/null
+# Release Note 1.16.1
+
+## ONE Compiler
+
+### Compiler Frontend
+
+- Extends the point where `one-codegen` finds backends.
--- /dev/null
+.. ONE documentation master file, created by
+ sphinx-quickstart on Thu May 20 12:56:12 2021.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+1.17
+====
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+ ./release-note-1.17.0.md
--- /dev/null
+# Release Note 1.17.0
+
+## ONE Compiler
+
+### Compiler Frontend
+
+- More optimization passes
+ - Remove Quant-Dequant sequence
+ - Replace Sub with Add
+ - Substitute StridedSlice to Reshape
+ - Fuse Mean with Mean
+ - Fuse Transpose with Mean
+ - Substitute PadV2 to Pad
+- Add new InstanceNorm pattern in `FuseInstanceNormPass`
+- Add verbose option
+- Introduce `onecc` driver to `one-cmds`
+- Introduce `one-profile` driver to `one-cmds`
+
+## ONE Runtime
+
+### gpu_cl backend added
+
+- New backend (gpu_cl) added. This backend exploits TensorFlow Lite's GPU delegate.
+- This backend supports the following operations: Add, Convolution, Depthwise Convolution, Pooling, Reshape, Relu, Softmax
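+
+A minimal sketch of selecting the new backend through the nnfw C API (the
+nnpackage path is a placeholder, gpu_cl availability depends on building with
+`BUILD_GPU_CL`, and error handling is omitted):
+
+```c
+#include <stddef.h>
+#include <nnfw.h>
+
+void run_on_gpu_cl(const char *nnpackage_path)
+{
+  nnfw_session *session = NULL;
+  nnfw_create_session(&session);
+  nnfw_load_model_from_file(session, nnpackage_path);
+  // Prefer gpu_cl, falling back to cpu for operations it does not cover
+  nnfw_set_available_backends(session, "gpu_cl;cpu");
+  nnfw_prepare(session);
+  // ... bind inputs/outputs with nnfw_set_input / nnfw_set_output, then:
+  nnfw_run(session);
+  nnfw_close_session(session);
+}
+```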
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
+1.5
===
.. toctree::
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
+1.6
===
.. toctree::
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
+1.7
===
.. toctree::
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
+1.8
===
.. toctree::
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-1.0
+1.9
===
.. toctree::
./1.11/index
./1.12/index
./1.13/index
+ ./1.14/index
+ ./1.15/index
+ ./1.16/index
file(MAKE_DIRECTORY "${TMP_DIR}")
message(STATUS "Download ${PREFIX} from ${URL}")
- file(DOWNLOAD ${URL} "${DOWNLOAD_PATH}"
- STATUS status
- LOG log)
- list(GET status 0 status_code)
- list(GET status 1 status_string)
+ foreach(retry_count RANGE 5)
+ message(STATUS "(Trial Count : ${retry_count})")
- if(NOT status_code EQUAL 0)
- message(FATAL_ERROR "error: downloading '${URL}' failed
+ file(DOWNLOAD ${URL} "${DOWNLOAD_PATH}"
+ STATUS status
+ LOG log)
+
+ list(GET status 0 status_code)
+ list(GET status 1 status_string)
+
+ # Download success
+ if(status_code EQUAL 0)
+ break()
+ endif()
+
+ message(WARNING "error: downloading '${URL}' failed
status_code: ${status_code}
status_string: ${status_string}
log: ${log}")
- endif()
+
+ # Retry limit exceeded
+ if(retry_count EQUAL 5)
+ message(FATAL_ERROR "Download ${PREFIX} from ${URL} - failed")
+ endif()
+
+ # Retry after 10 seconds when download fails
+ execute_process(COMMAND sleep 10)
+ endforeach()
+
message(STATUS "Download ${PREFIX} from ${URL} - done")
# Verify checksum
absl::numeric
absl::random_random
absl::strings
+ absl::status
absl::synchronization
absl::time
absl::utility
set(Abseil_FOUND TRUE PARENT_SCOPE)
endfunction(_Abseil_import)
+set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -fPIC")
+set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fPIC")
+set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} -fPIC")
+set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -fPIC")
+
_Abseil_import()
--- /dev/null
+function(_Opencl_Headers_import)
+ nnas_find_package(Opencl_HeadersSource QUIET)
+
+ # NOTE This line prevents multiple definitions of target
+ if(TARGET Headers)
+ set(Opencl_HeadersSource_DIR ${Opencl_HeadersSource_DIR} PARENT_SCOPE)
+ set(Opencl_Headers_FOUND TRUE PARENT_SCOPE)
+ return()
+ endif(TARGET Headers)
+
+ if(NOT Opencl_HeadersSource_FOUND)
+ message(STATUS "Opencl_Headers: Source not found")
+ set(Opencl_Headers_FOUND FALSE PARENT_SCOPE)
+ return()
+ endif(NOT Opencl_HeadersSource_FOUND)
+
+ add_extdirectory("${Opencl_HeadersSource_DIR}" OPENCL_HEADERS EXCLUDE_FROM_ALL)
+ set(Opencl_Headers_DIR ${Opencl_HeadersSource_DIR} PARENT_SCOPE)
+ set(Opencl_Headers_FOUND TRUE PARENT_SCOPE)
+endfunction(_Opencl_Headers_import)
+
+_Opencl_Headers_import()
--- /dev/null
+function(_Opencl_HeadersSource_import)
+ if(NOT DOWNLOAD_OPENCL_HEADERS)
+ set(Opencl_HeadersSource_FOUND FALSE PARENT_SCOPE)
+ return()
+ endif(NOT DOWNLOAD_OPENCL_HEADERS)
+
+ nnas_include(ExternalSourceTools)
+ nnas_include(OptionTools)
+
+ envoption(EXTERNAL_DOWNLOAD_SERVER "https://github.com")
+ envoption(OPENCL_HEADERS_URL ${EXTERNAL_DOWNLOAD_SERVER}/KhronosGroup/OpenCL-Headers/archive/v2021.04.29.tar.gz)
+
+ ExternalSource_Download(OPENCL_HEADERS
+ DIRNAME OPENCL_HEADERS
+ URL ${OPENCL_HEADERS_URL}
+ CHECKSUM MD5=5a7ea04265119aa76b4ecbd95f258219)
+
+ set(Opencl_HeadersSource_DIR ${OPENCL_HEADERS_SOURCE_DIR} PARENT_SCOPE)
+ set(Opencl_HeadersSource_FOUND TRUE PARENT_SCOPE)
+endfunction(_Opencl_HeadersSource_import)
+
+_Opencl_HeadersSource_import()
FILES_TO_CHECK_PERMISSION=()
for f in ${FILES_TO_CHECK[@]}; do
# Manually ignore permission checking
- if [[ ${f} == !(nnas|nnfw|nncc|*.sh|*.py|*/gradlew) ]] || [[ ${f} == tests/nnapi/specs/**/*.py ]]; then
+ if [[ ${f} == !(nnas|nnfw|nncc|*.sh|*.py|*/gradlew|infra/debian/compiler/rules|infra/debian/runtime/rules) ]] \
+ || [[ ${f} == tests/nnapi/specs/**/*.py ]]; then
FILES_TO_CHECK_PERMISSION+=("${f}")
fi
done
"${LCOV_PATH}" -e "${RAW_COVERAGE_INFO_PATH}" -o "${EXTRACTED_COVERAGE_INFO_PATH}" \
"${CANDIDATES[@]}"
+
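+# Exclude the gpu_cl backend's OpenCL sources from coverage report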
+opencl_files=($(find ./runtime/onert/backend/gpu_cl/open_cl/ \( -name "*.cc" -o -name "*.h" \) -exec realpath {} \; ))
+
# Exclude *.test.cpp files from coverage report
# Exclude flatbuffer generated files from coverage report
"${LCOV_PATH}" -r "${EXTRACTED_COVERAGE_INFO_PATH}" -o "${EXCLUDED_COVERAGE_INFO_PATH}" \
- '*.test.cpp' '*_schema_generated.h'
+ '*.test.cpp' '*_schema_generated.h' "${opencl_files[@]}"
# Final coverage data
cp -v ${EXCLUDED_COVERAGE_INFO_PATH} ${COVERAGE_INFO_PATH}
--- /dev/null
+one (1.17.0) bionic; urgency=medium
+
+ * More optimization passes
+ * Add new InstanceNorm pattern in `FuseInstanceNormPass`
+ * Add verbose option
+ * Introduce `onecc` driver to `one-cmds`
+ * Introduce `one-profile` driver to `one-cmds`
+
+ -- seongwoo <mhs4670go@naver.com> Fri, 20 Aug 2021 17:50:20 +0900
+
+one (1.16.1) bionic; urgency=medium
+
+ * Extends the point where `one-codegen` finds backends.
+
+ -- seongwoo chae <mhs4670go@naver.com> Wed, 26 May 2021 18:06:53 +0900
+
+one (1.16.0) bionic; urgency=low
+
+ * Initial release.
+
+ -- seongwoo chae <mhs4670go@naver.com> Mon, 26 Apr 2021 14:34:57 +0900
--- /dev/null
+Source: one
+Section: devel
+Priority: extra
+Maintainer: Neural Network Acceleration Solution Developers <nnfw@samsung.com>
+Build-Depends: cmake, debhelper (>=9), dh-python, python3-all
+Standards-Version: 3.9.8
+Homepage: https://github.com/Samsung/ONE
+
+Package: one-compiler
+Architecture: amd64
+Multi-Arch: foreign
+Depends: ${misc:Depends}, ${shlibs:Depends}, python3-venv, python3-pip
+Description: On-device Neural Engine compiler package
+
+Package: one-compiler-dev
+Architecture: amd64
+Multi-Arch: same
+Depends: one-compiler, ${shlibs:Depends}, ${misc:Depends}
+Description: one-compiler development package
+
+Package: one-compiler-test
+Architecture: amd64
+Multi-Arch: same
+Depends: one-compiler, ${shlibs:Depends}, ${misc:Depends}
+Description: one-compiler test package
--- /dev/null
+Files: *
+License: Proprietary
+Copyright (c) <2018> <Samsung Electronics Co.,Ltd.>
--- /dev/null
+.TH ONE-BUILD "1" "August 2021" "one-build version 1.17.0" "User Commands"
+.SH NAME
+one-build \- run ONE drivers
+.SH DESCRIPTION
+usage: one\-build [\-h] [\-v] [\-V] [\-C CONFIG]
+.PP
+\fBone\-build\fR is a command line tool that runs ONE drivers in customized order.
+.SS "Configuration file:"
+\fBone\-build\fR takes as input a configuration file in ini format.
+A configuration file consists of sections, each led by a [section] header.
+Each section is the ONE driver you want to run, and consists of commands in a key/value combination to pass to the driver.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.SH EXAMPLES
+Before you run \fBone\-build\fR, you must write a configuration file.
+.PP
+$ cat one-build.template.cfg
+.PP
+[one-build]
+.br
+one-import-tf=True
+.br
+one-import-tflite=False
+.br
+one-import-bcq=False
+.br
+one-import-onnx=False
+.br
+one-optimize=True
+.br
+one-quantize=False
+.br
+one-pack=True
+.br
+one-codegen=False
+.PP
+[one-import-tf]
+.br
+input_path=/path/to/inception_v3.pb
+.br
+output_path=inception_v3.circle
+.br
+input_arrays=input
+.br
+input_shapes=1,299,299,3
+.br
+output_arrays=InceptionV3/Predictions/Reshape_1
+.br
+converter_version=v1
+.br
+model_format=graph_def
+.PP
+[one-optimize]
+.br
+input_path=inception_v3.circle
+.br
+output_path=inception_v3.opt.circle
+.br
+generate_profile_data=False
+.PP
+[one-pack]
+.br
+input_path=inception_v3.opt.circle
+.br
+output_path=inception_v3_pack
+.PP
+\fBone\-build\fR section decides whether to use each driver or not.
+If the value is False, even if the corresponding section exists, the driver won't be executed.
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-build
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-build
+programs are properly installed at your site, the command
+.IP
+.B info one-build
+.PP
+should give you access to the complete manual.
+
--- /dev/null
+.TH ONE-CODEGEN "1" "August 2021" "one-codegen version 1.17.0" "User Commands"
+.SH NAME
+one-codegen \- generate code
+.SH DESCRIPTION
+usage: one\-codegen [\-h] [\-v] [\-V] [\-C CONFIG] [\-b BACKEND] [\-\-] [COMMANDS FOR BACKEND]
+.PP
+\fBone\-codegen\fR is a command line tool for code generation.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.TP
+\fB\-b\fR BACKEND, \fB\-\-backend\fR BACKEND
+backend name to use
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-codegen
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-codegen
+programs are properly installed at your site, the command
+.IP
+.B info one-codegen
+.PP
+should give you access to the complete manual.
--- /dev/null
+.TH ONE-IMPORT-BCQ "1" "August 2021" "one-import-bcq version 1.17.0" "User Commands"
+.SH NAME
+one-import-bcq \- convert TensorFlow with BCQ to circle
+.SH DESCRIPTION
+usage: one\-import\-bcq [\-h] [\-v] [\-V] [\-C CONFIG] [\-\-v1 | \-\-v2] [\-i INPUT_PATH]
+.br
+[\-o OUTPUT_PATH] [\-I INPUT_ARRAYS] [\-s INPUT_SHAPES]
+.br
+[\-O OUTPUT_ARRAYS]
+.PP
+\fBone\-import\-bcq\fR is a command line tool to convert TensorFlow with BCQ to circle.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.TP
+\fB\-\-v1\fR
+use TensorFlow Lite Converter 1.x
+.TP
+\fB\-\-v2\fR
+use TensorFlow Lite Converter 2.x
+.TP
+\fB\-i\fR INPUT_PATH, \fB\-\-input_path\fR INPUT_PATH
+full filepath of the input file
+.TP
+\fB\-o\fR OUTPUT_PATH, \fB\-\-output_path\fR OUTPUT_PATH
+full filepath of the output file
+.TP
+\fB\-I\fR INPUT_ARRAYS, \fB\-\-input_arrays\fR INPUT_ARRAYS
+names of the input arrays, comma\-separated
+.TP
+\fB\-s\fR INPUT_SHAPES, \fB\-\-input_shapes\fR INPUT_SHAPES
+shapes corresponding to \fB\-\-input_arrays\fR, colon\-separated (ex:"1,4,4,3:1,20,20,3")
+.TP
+\fB\-O\fR OUTPUT_ARRAYS, \fB\-\-output_arrays\fR OUTPUT_ARRAYS
+names of the output arrays, comma\-separated
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-import-bcq
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-import-bcq
+programs are properly installed at your site, the command
+.IP
+.B info one-import-bcq
+.PP
+should give you access to the complete manual.
--- /dev/null
+.TH ONE-IMPORT-ONNX "1" "August 2021" "one-import-onnx version 1.17.0" "User Commands"
+.SH NAME
+one-import-onnx \- convert ONNX to circle
+.SH DESCRIPTION
+usage: one\-import\-onnx [\-h] [\-v] [\-V] [\-C CONFIG] [\-i INPUT_PATH]
+.br
+[\-o OUTPUT_PATH] [\-I INPUT_ARRAYS] [\-O OUTPUT_ARRAYS]
+.br
+[\-\-model_format MODEL_FORMAT]
+.br
+[\-\-converter_version CONVERTER_VERSION]
+.br
+[\-\-save_intermediate]
+.PP
+\fBone\-import\-onnx\fR is a command line tool to convert ONNX to circle.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.TP
+\fB\-\-save_intermediate\fR
+Save intermediate files to output folder
+.TP
+\fB\-i\fR INPUT_PATH, \fB\-\-input_path\fR INPUT_PATH
+full filepath of the input file
+.TP
+\fB\-o\fR OUTPUT_PATH, \fB\-\-output_path\fR OUTPUT_PATH
+full filepath of the output file
+.TP
+\fB\-I\fR INPUT_ARRAYS, \fB\-\-input_arrays\fR INPUT_ARRAYS
+names of the input arrays, comma\-separated
+.TP
+\fB\-O\fR OUTPUT_ARRAYS, \fB\-\-output_arrays\fR OUTPUT_ARRAYS
+names of the output arrays, comma\-separated
+.HP
+\fB\-\-model_format\fR MODEL_FORMAT
+.HP
+\fB\-\-converter_version\fR CONVERTER_VERSION
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-import-onnx
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-import-onnx
+programs are properly installed at your site, the command
+.IP
+.B info one-import-onnx
+.PP
+should give you access to the complete manual.
--- /dev/null
+.TH ONE-IMPORT-TF "1" "August 2021" "one-import-tf version 1.17.0" "User Commands"
+.SH NAME
+one-import-tf \- convert TensorFlow to circle
+.SH DESCRIPTION
+usage: one\-import\-tf [\-h] [\-v] [\-V] [\-C CONFIG] [\-\-v1 | \-\-v2]
+.br
+[\-\-graph_def | \-\-saved_model | \-\-keras_model]
+.br
+[\-i INPUT_PATH] [\-o OUTPUT_PATH] [\-I INPUT_ARRAYS]
+.br
+[\-s INPUT_SHAPES] [\-O OUTPUT_ARRAYS]
+.br
+[\-\-save_intermediate]
+.PP
+\fBone\-import\-tf\fR is a command line tool to convert TensorFlow model to circle.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.TP
+\fB\-\-save_intermediate\fR
+Save intermediate files to output folder
+.TP
+\fB\-\-v1\fR
+use TensorFlow Lite Converter 1.x
+.TP
+\fB\-\-v2\fR
+use TensorFlow Lite Converter 2.x
+.TP
+\fB\-\-graph_def\fR
+use graph def file (default)
+.TP
+\fB\-\-saved_model\fR
+use saved model
+.TP
+\fB\-\-keras_model\fR
+use keras model
+.TP
+\fB\-i\fR INPUT_PATH, \fB\-\-input_path\fR INPUT_PATH
+full filepath of the input file
+.TP
+\fB\-o\fR OUTPUT_PATH, \fB\-\-output_path\fR OUTPUT_PATH
+full filepath of the output file
+.TP
+\fB\-I\fR INPUT_ARRAYS, \fB\-\-input_arrays\fR INPUT_ARRAYS
+names of the input arrays, comma\-separated
+.TP
+\fB\-s\fR INPUT_SHAPES, \fB\-\-input_shapes\fR INPUT_SHAPES
+shapes corresponding to \fB\-\-input_arrays\fR, colon\-separated (ex:"1,4,4,3:1,20,20,3")
+.TP
+\fB\-O\fR OUTPUT_ARRAYS, \fB\-\-output_arrays\fR OUTPUT_ARRAYS
+names of the output arrays, comma\-separated
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-import-tf
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-import-tf
+programs are properly installed at your site, the command
+.IP
+.B info one-import-tf
+.PP
+should give you access to the complete manual.
--- /dev/null
+.TH ONE-IMPORT-TFLITE "1" "August 2021" "one-import-tflite version 1.17.0" "User Commands"
+.SH NAME
+one-import-tflite \- convert TensorFlow lite to circle
+.SH DESCRIPTION
+usage: one\-import\-tflite [\-h] [\-v] [\-V] [\-C CONFIG] [\-i INPUT_PATH]
+.br
+[\-o OUTPUT_PATH]
+.PP
+\fBone\-import\-tflite\fR is a command line tool to convert TensorFlow lite to circle.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.TP
+\fB\-i\fR INPUT_PATH, \fB\-\-input_path\fR INPUT_PATH
+full filepath of the input file
+.TP
+\fB\-o\fR OUTPUT_PATH, \fB\-\-output_path\fR OUTPUT_PATH
+full filepath of the output file
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-import-tflite
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-import-tflite
+programs are properly installed at your site, the command
+.IP
+.B info one-import-tflite
+.PP
+should give you access to the complete manual.
--- /dev/null
+.TH ONE-IMPORT "1" "August 2021" "one-import version 1.17.0" "User Commands"
+.SH NAME
+one-import \- convert various formats to circle
+.SH SYNOPSIS
+usage: one\-import [\-h] [\-C CONFIG] [\-v] driver
+.SH DESCRIPTION
+\fBone\-import\fR is a command line tool to convert various formats to circle.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.TP
+\fBdriver\fR driver name to run (supported: tf, tflite, bcq, onnx)
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-import
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-import
+programs are properly installed at your site, the command
+.IP
+.B info one-import
+.PP
+should give you access to the complete manual.
--- /dev/null
+.TH ONE-OPTIMIZE "1" "August 2021" "one-optimize version 1.17.0" "User Commands"
+.SH NAME
+one-optimize \- optimize circle model
+.SH DESCRIPTION
+usage: one\-optimize [\-h] [\-v] [\-V] [\-C CONFIG] [\-p]
+.br
+[\-\-change_outputs CHANGE_OUTPUTS] [\-i INPUT_PATH]
+.br
+[\-o OUTPUT_PATH] [\-\-O1] [\-\-convert_nchw_to_nhwc]
+.br
+[\-\-nchw_to_nhwc_input_shape] [\-\-nchw_to_nhwc_output_shape]
+.br
+[\-\-fold_add_v2] [\-\-fold_cast] [\-\-fold_dequantize]
+.br
+[\-\-fold_sparse_to_dense] [\-\-forward_reshape_to_unaryop]
+.br
+[\-\-fuse_add_with_tconv] [\-\-fuse_batchnorm_with_conv]
+.br
+[\-\-fuse_batchnorm_with_dwconv]
+.br
+[\-\-fuse_batchnorm_with_tconv] [\-\-fuse_bcq]
+.br
+[\-\-fuse_preactivation_batchnorm]
+.br
+[\-\-make_batchnorm_gamma_positive]
+.br
+[\-\-fuse_activation_function] [\-\-fuse_instnorm]
+.br
+[\-\-replace_cw_mul_add_with_depthwise_conv]
+.br
+[\-\-remove_fakequant] [\-\-remove_quantdequant]
+.br
+[\-\-remove_redundant_reshape]
+.br
+[\-\-remove_redundant_transpose]
+.br
+[\-\-remove_unnecessary_reshape]
+.br
+[\-\-remove_unnecessary_slice]
+.br
+[\-\-remove_unnecessary_strided_slice]
+.br
+[\-\-remove_unnecessary_split] [\-\-resolve_customop_add]
+.br
+[\-\-resolve_customop_batchmatmul]
+.br
+[\-\-resolve_customop_matmul]
+.br
+[\-\-shuffle_weight_to_16x1float32]
+.br
+[\-\-substitute_pack_to_reshape]
+.br
+[\-\-substitute_squeeze_to_reshape]
+.br
+[\-\-substitute_transpose_to_reshape]
+.br
+[\-\-transform_min_max_to_relu6]
+.br
+[\-\-transform_min_relu_to_relu6]
+.PP
+\fBone\-optimize\fR is a command line tool to optimize circle model.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.SS "arguments for utility:"
+.TP
+\fB\-p\fR, \fB\-\-generate_profile_data\fR
+generate profiling data
+.TP
+\fB\-\-change_outputs\fR CHANGE_OUTPUTS
+Experimental: Change first subgraph output nodes to
+CSV names
+.SS "arguments for optimization:"
+.TP
+\fB\-i\fR INPUT_PATH, \fB\-\-input_path\fR INPUT_PATH
+full filepath of the input file
+.TP
+\fB\-o\fR OUTPUT_PATH, \fB\-\-output_path\fR OUTPUT_PATH
+full filepath of the output file
+.TP
+\fB\-\-O1\fR
+enable O1 optimization pass
+.TP
+\fB\-\-convert_nchw_to_nhwc\fR
+Experimental: This will convert NCHW operators to NHWC
+under the assumption that input model is NCHW.
+.TP
+\fB\-\-nchw_to_nhwc_input_shape\fR
+convert the input shape of the model (argument for
+convert_nchw_to_nhwc)
+.TP
+\fB\-\-nchw_to_nhwc_output_shape\fR
+convert the output shape of the model (argument for
+convert_nchw_to_nhwc)
+.TP
+\fB\-\-fold_add_v2\fR
+fold AddV2 op with constant inputs
+.TP
+\fB\-\-fold_cast\fR
+fold Cast op with constant input
+.TP
+\fB\-\-fold_dequantize\fR
+fold Dequantize op
+.TP
+\fB\-\-fold_sparse_to_dense\fR
+fold SparseToDense op
+.TP
+\fB\-\-forward_reshape_to_unaryop\fR
+Forward Reshape op
+.TP
+\fB\-\-fuse_add_with_tconv\fR
+fuse Add op to Transposed Convolution op
+.TP
+\fB\-\-fuse_batchnorm_with_conv\fR
+fuse BatchNorm op to Convolution op
+.TP
+\fB\-\-fuse_batchnorm_with_dwconv\fR
+fuse BatchNorm op to Depthwise Convolution op
+.TP
+\fB\-\-fuse_batchnorm_with_tconv\fR
+fuse BatchNorm op to Transposed Convolution op
+.TP
+\fB\-\-fuse_bcq\fR
+apply Binary Coded Quantization
+.TP
+\fB\-\-fuse_preactivation_batchnorm\fR
+fuse BatchNorm operators of pre\-activations to
+Convolution op
+.TP
+\fB\-\-make_batchnorm_gamma_positive\fR
+make negative gamma of BatchNorm to a small positive
+value (1e\-10). Note that this pass can change the
+execution result of the model. So, use it only when
+the impact is known to be acceptable.
+.TP
+\fB\-\-fuse_activation_function\fR
+fuse Activation function to a preceding operator
+.TP
+\fB\-\-fuse_instnorm\fR
+fuse ops to InstanceNorm operator
+.TP
+\fB\-\-replace_cw_mul_add_with_depthwise_conv\fR
+replace channel\-wise Mul/Add with DepthwiseConv2D
+.TP
+\fB\-\-remove_fakequant\fR
+remove FakeQuant ops
+.TP
+\fB\-\-remove_quantdequant\fR
+remove Quantize\-Dequantize sequence
+.TP
+\fB\-\-remove_redundant_reshape\fR
+fuse or remove subsequent Reshape ops
+.TP
+\fB\-\-remove_redundant_transpose\fR
+fuse or remove subsequent Transpose ops
+.TP
+\fB\-\-remove_unnecessary_reshape\fR
+remove unnecessary reshape ops
+.TP
+\fB\-\-remove_unnecessary_slice\fR
+remove unnecessary slice ops
+.TP
+\fB\-\-remove_unnecessary_strided_slice\fR
+remove unnecessary strided slice ops
+.TP
+\fB\-\-remove_unnecessary_split\fR
+remove unnecessary split ops
+.TP
+\fB\-\-resolve_customop_add\fR
+convert Custom(Add) op to Add op
+.TP
+\fB\-\-resolve_customop_batchmatmul\fR
+convert Custom(BatchMatmul) op to BatchMatmul op
+.TP
+\fB\-\-resolve_customop_matmul\fR
+convert Custom(Matmul) op to Matmul op
+.TP
+\fB\-\-shuffle_weight_to_16x1float32\fR
+convert weight format of FullyConnected op to
+SHUFFLED16x1FLOAT32. Note that it only converts
+weights whose row is a multiple of 16
+.TP
+\fB\-\-substitute_pack_to_reshape\fR
+convert single input Pack op to Reshape op
+.TP
+\fB\-\-substitute_squeeze_to_reshape\fR
+convert certain condition Squeeze to Reshape
+.TP
+\fB\-\-substitute_transpose_to_reshape\fR
+convert certain condition Transpose to Reshape
+.TP
+\fB\-\-transform_min_max_to_relu6\fR
+transform Minimum\-Maximum pattern to Relu6 op
+.TP
+\fB\-\-transform_min_relu_to_relu6\fR
+transform Minimum(6)\-Relu pattern to Relu6 op
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-optimize
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-optimize
+programs are properly installed at your site, the command
+.IP
+.B info one-optimize
+.PP
+should give you access to the complete manual.
--- /dev/null
+.TH ONE-PACK "1" "August 2021" "one-pack version 1.17.0" "User Commands"
+.SH NAME
+one-pack \- package circle and metadata into nnpackage
+.SH DESCRIPTION
+usage: one\-pack [\-h] [\-v] [\-V] [\-C CONFIG] [\-i INPUT_PATH] [\-o OUTPUT_PATH]
+.PP
+\fBone\-pack\fR is a command line tool to package circle and metadata into nnpackage.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.TP
+\fB\-i\fR INPUT_PATH, \fB\-\-input_path\fR INPUT_PATH
+full filepath of the input file
+.TP
+\fB\-o\fR OUTPUT_PATH, \fB\-\-output_path\fR OUTPUT_PATH
+full filepath of the output file
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-pack
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-pack
+programs are properly installed at your site, the command
+.IP
+.B info one-pack
+.PP
+should give you access to the complete manual.
--- /dev/null
+.TH ONE-PROFILE "1" "August 2021" "one-profile version 1.17.0" "User Commands"
+.SH NAME
+one-profile \- profile backend model file
+.SH DESCRIPTION
+usage: one\-profile [\-h] [\-v] [\-V] [\-C CONFIG] [\-b BACKEND] [\-\-] [COMMANDS FOR BACKEND]
+.PP
+\fBone\-profile\fR is a command line tool for profiling backend model.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.TP
+\fB\-b\fR BACKEND, \fB\-\-backend\fR BACKEND
+backend name to use
+.SH COPYRIGHT
+Copyright \(co 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-profile
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-profile
+programs are properly installed at your site, the command
+.IP
+.B info one-profile
+.PP
+should give you access to the complete manual.
--- /dev/null
+.TH ONE-QUANTIZE "1" "August 2021" "one-quantize version 1.17.0" "User Commands"
+.SH NAME
+one-quantize \- quantize circle model
+.SH DESCRIPTION
+usage: one\-quantize [\-h] [\-v] [\-V] [\-C CONFIG] [\-i INPUT_PATH] [\-d INPUT_DATA]
+.br
+[\-f INPUT_DATA_FORMAT] [\-o OUTPUT_PATH] [\-p]
+.br
+[\-\-input_dtype INPUT_DTYPE]
+.br
+[\-\-quantized_dtype QUANTIZED_DTYPE]
+.br
+[\-\-granularity GRANULARITY]
+.br
+[\-\-min_percentile MIN_PERCENTILE]
+.br
+[\-\-max_percentile MAX_PERCENTILE] [\-\-mode MODE]
+.PP
+\fBone\-quantize\fR is a command line tool to quantize circle model.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.TP
+\fB\-i\fR INPUT_PATH, \fB\-\-input_path\fR INPUT_PATH
+full filepath of the input file
+.TP
+\fB\-d\fR INPUT_DATA, \fB\-\-input_data\fR INPUT_DATA
+full filepath of the input data file. If not
+specified, run with random input data.
+.TP
+\fB\-o\fR OUTPUT_PATH, \fB\-\-output_path\fR OUTPUT_PATH
+full filepath of the output file
+.TP
+\fB\-p\fR, \fB\-\-generate_profile_data\fR
+generate profiling data
+.SS "arguments for quantization:"
+.TP
+\fB\-\-input_dtype\fR INPUT_DTYPE
+input data type (supported: float32, default=float32)
+.TP
+\fB\-\-quantized_dtype\fR QUANTIZED_DTYPE
+output quantized data type (supported: uint8, int16,
+default=uint8)
+.TP
+\fB\-\-granularity\fR GRANULARITY
+quantize granularity (supported: layer, channel,
+default=layer)
+.TP
+\fB\-\-min_percentile\fR MIN_PERCENTILE
+minimum percentile (0.0~100.0, default=1.0)
+.TP
+\fB\-\-max_percentile\fR MAX_PERCENTILE
+maximum percentile (0.0~100.0, default=99.0)
+.TP
+\fB\-\-mode\fR MODE
+record mode (supported: percentile/moving_average,
+default=percentile)
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B one-quantize
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B one-quantize
+programs are properly installed at your site, the command
+.IP
+.B info one-quantize
+.PP
+should give you access to the complete manual.
--- /dev/null
+.\" Manpage for onecc.
+.\" Contact nnfw@samsung.com to correct errors or typos.
+.TH ONECC "1" "August 2021" "onecc version 1.17.0" "User Commands"
+.SH NAME
+onecc \- run ONE driver via several commands or configuration file
+.SH SYNOPSIS
+\fBonecc\fR [\-h] [\-v] [\-V] [\-C CONFIG] [COMMAND <args>]
+.SH DESCRIPTION
+\fBonecc\fR is a command line tool to execute ONE driver via several commands or configuration file.
+.SS "Configuration file:"
+\fBonecc\fR takes as input a configuration file in ini format.
+A configuration file consists of sections, each led by a [section] header.
+Each section is the ONE driver you want to run, and consists of commands in a key/value combination to pass to the driver.
+.SH OPTIONS
+.TP
+\fB\-h\fR, \fB\-\-help\fR
+show this help message and exit
+.TP
+\fB\-v\fR, \fB\-\-version\fR
+show program's version number and exit
+.TP
+\fB\-V\fR, \fB\-\-verbose\fR
+output additional information to stdout or stderr
+.TP
+\fB\-C\fR CONFIG, \fB\-\-config\fR CONFIG
+run with configuration file
+.SS compile to circle model
+.TP
+\fBimport\fR
+Convert given model to circle. See one\-import(1) for details.
+.TP
+\fBoptimize\fR
+Optimize circle model. See one-optimize(1) for details.
+.TP
+\fBquantize\fR
+Quantize circle model. See one-quantize(1) for details.
+.SS package circle model
+.TP
+\fBpack\fR
+Package circle and metadata into nnpackage. See one-pack(1) for details.
+.SS run backend tools
+.TP
+\fBcodegen\fR
+Code generation tool. See one-codegen(1) for details.
+.TP
+\fBprofile\fR
+Profile backend model file. See one-profile(1) for details.
+.SH EXAMPLES
+.SS Use command line interface
+.TP
+\fBonecc import tf --v1 -i\fR \fIinput_path\fR \fB-o\fR \fIoutput_path\fR \fB-I\fR \fIinput_arrays\fR \fB-s\fR \fIinput_shapes\fR \fB-O\fR \fIoutput_arrays\fR
+import tf model
+.TP
+\fBonecc import tflite -i\fR \fIinput_path\fR \fB-o\fR \fIoutput_path\fR
+import tflite model
+.TP
+\fBonecc import onnx -i\fR \fIinput_path\fR \fB-o\fR \fIoutput_path\fR
+import onnx model
+.TP
+\fBonecc optimize -i\fR \fIinput_path\fR \fB-o\fR \fIoutput_path\fR \fIoptimize_arguments\fR
+optimize circle model
+.TP
+\fBonecc quantize -i\fR \fIinput_path\fR \fB-o\fR \fIoutput_path\fR \fB-d\fR \fIinput_data\fR
+quantize circle model
+.TP
+\fBonecc pack -i\fR \fIinput_path\fR \fB-o\fR \fIoutput_path\fR
+package circle and metadata into nnpackage
+.TP
+\fBonecc codegen -b\fR \fIbackend\fR \fB--\fR \fIbackends_arguments\fR
+generate backend code
+.TP
+\fBonecc profile -b\fR \fIbackend\fR \fB--\fR \fIbackends_arguments\fR
+profile backend model
+.PP
+.SS Use configuration file
+.PP
+The configuration file should be written in the following format:
+.IP
+[onecc]
+.br
+one-import-tf=True
+.br
+one-import-tflite=False
+.br
+one-import-bcq=False
+.br
+one-import-onnx=False
+.br
+one-optimize=True
+.br
+one-quantize=True
+.br
+one-pack=True
+.br
+one-codegen=True
+.br
+one-profile=True
+.IP
+[one-import-tf]
+.br
+input_path=/path/to/inception_v3.pb
+.br
+output_path=inception_v3.circle
+.br
+input_arrays=input
+.br
+input_shapes=1,299,299,3
+.br
+output_arrays=InceptionV3/Predictions/Reshape_1
+.br
+converter_version=v1
+.br
+model_format=graph_def
+.IP
+[one-optimize]
+.br
+input_path=inception_v3.circle
+.br
+output_path=inception_v3.opt.circle
+.br
+generate_profile_data=False
+.IP
+[one-quantize]
+.br
+input_path=inception_v3.opt.circle
+.br
+output_path=inception_v3.quantized.circle
+.br
+input_data=inception_v3_test_data.h5
+.IP
+[one-pack]
+.br
+input_path=inception_v3.quantized.circle
+.br
+output_path=inception_v3_pack
+.IP
+[one-codegen]
+.br
+backend=dummy
+.br
+command=-o sample.out inception_v3.quantized.circle
+.IP
+[one-profile]
+.br
+backend=dummy
+.br
+command=sample.out
+.TP
+\fBonecc -C\fR \fIconfiguration file\fR
+Run ONE driver according to configuration section parameter
+.PP
+\fBonecc\fR section decides whether to use each driver or not.
+If the value is False, even if the corresponding section exists, the driver won't be executed.
+.SH COPYRIGHT
+Copyright \(co 2020\-2021 Samsung Electronics Co., Ltd. All Rights Reserved
+Licensed under the Apache License, Version 2.0
+https://github.com/Samsung/ONE
+.SH "SEE ALSO"
+The full documentation for
+.B onecc
+is maintained as a Texinfo manual. If the
+.B info
+and
+.B onecc
+programs are properly installed at your site, the command
+.IP
+.B info onecc
+.PP
+should give you access to the complete manual.
+
--- /dev/null
+# {FILES_TO_INSTALL} {DEST_DIR}
+# bin
+usr/bin/circledump usr/share/one/bin/
+usr/bin/circle-tensordump usr/share/one/bin/
+usr/bin/tflchef usr/share/one/bin/
+usr/bin/tflchef-file usr/share/one/bin/
+usr/bin/tflchef-reverse usr/share/one/bin/
+# include
+usr/include/* usr/share/one/include/
--- /dev/null
+# bin
+usr/share/one/bin/circledump usr/bin/circledump
+usr/share/one/bin/circle-tensordump usr/bin/circle-tensordump
+usr/share/one/bin/tflchef usr/bin/tflchef
+usr/share/one/bin/tflchef-file usr/bin/tflchef-file
+usr/share/one/bin/tflchef-reverse usr/bin/tflchef-reverse
--- /dev/null
+# {FILES_TO_INSTALL} {DEST_DIR}
+# bin
+usr/bin/luci_eval_driver usr/share/one/bin/
+# test
+usr/test/* usr/share/one/test/
--- /dev/null
+# {FILES_TO_INSTALL} {DEST_DIR}
+# bin
+usr/bin/circle2circle usr/share/one/bin/
+usr/bin/circle_partitioner usr/share/one/bin/
+usr/bin/circle-quantizer usr/share/one/bin/
+usr/bin/conv_mixin_1.8.0.patch usr/share/one/bin/
+usr/bin/generate_bcq_metadata.py usr/share/one/bin/
+usr/bin/generate_bcq_output_arrays.py usr/share/one/bin/
+usr/bin/model2nnpkg.sh usr/share/one/bin/
+usr/bin/onecc usr/share/one/bin/
+usr/bin/onecc.template.cfg usr/share/one/bin/
+usr/bin/one-build usr/share/one/bin/
+usr/bin/one-build.template.cfg usr/share/one/bin/
+usr/bin/one-codegen usr/share/one/bin/
+usr/bin/one-import usr/share/one/bin/
+usr/bin/one-import-bcq usr/share/one/bin/
+usr/bin/one-import-onnx usr/share/one/bin/
+usr/bin/one-import-tf usr/share/one/bin/
+usr/bin/one-import-tflite usr/share/one/bin/
+usr/bin/one-optimize usr/share/one/bin/
+usr/bin/one-pack usr/share/one/bin/
+usr/bin/one-prepare-venv usr/share/one/bin/
+usr/bin/one-profile usr/share/one/bin/
+usr/bin/one-quantize usr/share/one/bin/
+usr/bin/one-version usr/share/one/bin/
+usr/bin/rawdata2hdf5 usr/share/one/bin/
+usr/bin/record-minmax usr/share/one/bin/
+usr/bin/tf2nnpkg usr/share/one/bin/
+usr/bin/tf2tfliteV2.py usr/share/one/bin/
+usr/bin/tflite2circle usr/share/one/bin/
+usr/bin/utils.py usr/share/one/bin/
+# lib
+usr/lib/* usr/share/one/lib/
+# doc
+usr/doc/* usr/share/one/doc/
--- /dev/null
+# bin
+usr/share/one/bin/one-build usr/bin/one-build
+usr/share/one/bin/onecc usr/bin/onecc
+# lib
+usr/share/one/lib/libloco.so usr/lib/libloco.so
+usr/share/one/lib/libluci_env.so usr/lib/libluci_env.so
+usr/share/one/lib/libluci_export.so usr/lib/libluci_export.so
+usr/share/one/lib/libluci_import.so usr/lib/libluci_import.so
+usr/share/one/lib/libluci_interpreter.so usr/lib/libluci_interpreter.so
+usr/share/one/lib/libluci_lang.so usr/lib/libluci_lang.so
+usr/share/one/lib/libluci_logex.so usr/lib/libluci_logex.so
+usr/share/one/lib/libluci_log.so usr/lib/libluci_log.so
+usr/share/one/lib/libluci_partition.so usr/lib/libluci_partition.so
+usr/share/one/lib/libluci_pass.so usr/lib/libluci_pass.so
+usr/share/one/lib/libluci_profile.so usr/lib/libluci_profile.so
+usr/share/one/lib/libluci_service.so usr/lib/libluci_service.so
--- /dev/null
+debian/docs/one-build.1
+debian/docs/one-codegen.1
+debian/docs/one-import.1
+debian/docs/one-import-bcq.1
+debian/docs/one-import-onnx.1
+debian/docs/one-import-tf.1
+debian/docs/one-import-tflite.1
+debian/docs/one-optimize.1
+debian/docs/one-pack.1
+debian/docs/one-profile.1
+debian/docs/one-quantize.1
+debian/docs/onecc.1
--- /dev/null
+#!/bin/bash
+
+# https://www.debian.org/doc/debian-policy/ch-maintainerscripts.html
+# Broadly speaking, the `postinst` is called after a package is unpacked.
+
+set -e
+
+# This script is invoked as root, but the environment variables are still the
+# invoking user's, which causes permission problems.
+# e.g. When `pip` installs user packages, it proceeds based on $HOME.
+# For a proper installation, $HOME should be root's.
+su - $(whoami) -c '/usr/share/one/bin/one-prepare-venv' # $(whoami) = root
--- /dev/null
+#!/bin/bash
+
+set -e
+
+case "$1" in
+ remove|purge)
+ rm -rf /usr/share/one/
+ ;;
+ upgrade)
+ # DO NOTHING
+ ;;
+ failed-upgrade|abort-install|abort-upgrade)
+ # DO NOTHING
+ ;;
+ *)
+ # DO NOTHING
+ ;;
+esac
--- /dev/null
+#!/usr/bin/make -f
+export DH_VERBOSE = 1
+export NNAS_BUILD_PREFIX = build
+export PRESET = 20210706
+export _DESTDIR = debian/tmp/usr
+
+%:
+ dh $@
+
+override_dh_auto_build:
+ ./nnas create-package --preset $(PRESET) --prefix "$(_DESTDIR)"
+
+override_dh_auto_install:
+ cmake --build "$(NNAS_BUILD_PREFIX)/nncc" -- install
+
+override_dh_install:
+ install -t "$(_DESTDIR)/bin" -D "tools/nnpackage_tool/model2nnpkg/model2nnpkg.sh"
+ install -T -m 755 -D "infra/packaging/res/tf2nnpkg.${PRESET}" "$(_DESTDIR)/bin/tf2nnpkg"
+ dh_install
+
+override_dh_builddeb:
+ dh_builddeb --destdir=$(NNAS_BUILD_PREFIX)
--- /dev/null
+3.0 (native)
--- /dev/null
+# This is for reproducible building. Otherwise, `debuild` recognizes build artifacts as source files.
+diff-ignore="build|externals"
--- /dev/null
+one (1.17.0) bionic; urgency=low
+
+ * New gpu_cl backend supports the following operations: Add, Convolution, Depthwise Convolution, Pooling, Reshape, Relu, Softmax
+
+ -- Chunseok Lee <chunseok.lee@samsung.com> Fri, 20 Aug 2021 17:00:00 +0900
+
+one (1.16.0) bionic; urgency=low
+
+ * Initial release.
+
+ -- Chunseok Lee <chunseok.lee@samsung.com> Mon, 05 Jul 2021 17:11:00 +0900
--- /dev/null
+Source: one
+Section: devel
+Priority: extra
+Maintainer: Neural Network Acceleration Solution Developers <nnfw@samsung.com>
+Build-Depends: cmake, debhelper (>=9), dh-python, python3-all
+Standards-Version: 3.9.8
+Homepage: https://github.com/Samsung/ONE
+
+Package: nnfw
+Architecture: amd64
+Multi-Arch: same
+Depends: ${shlibs:Depends}, ${misc:Depends}
+Description: one-runtime package
+
+Package: nnfw-dev
+Architecture: amd64
+Multi-Arch: same
+Depends: nnfw, ${shlibs:Depends}, ${misc:Depends}
+Description: one-runtime development package
--- /dev/null
+Files: *
+License: Proprietary
+Copyright (c) <2018> <Samsung Electronics Co.,Ltd.>
--- /dev/null
+# {FILES_TO_INSTALL} {DEST_DIR}
+# include
+usr/include/nnfw usr/include/
+usr/lib/pkgconfig/*.pc usr/lib/pkgconfig/
--- /dev/null
+# {FILES_TO_INSTALL} {DEST_DIR}
+# lib
+usr/lib/*.so usr/lib/
--- /dev/null
+#!/usr/bin/make -f
+DEBVER := $(shell dpkg-parsechangelog -SVersion)
+export DH_VERBOSE = 1
+export _DESTDIR = debian/tmp/
+export BUILD_TYPE=release
+export OPTIONS=-DBUILD_LOGGING=0 -DBUILD_TFLITE_COMPARATOR_TEST_TOOL=0 -DBUILD_NNPACKAGE_RUN=0 -DBUILD_TFLITE_RUN=0 -DBUILD_NNAPI_TEST=0 -DBUILD_RUNTIME_NNAPI_TEST=0 -DBUILD_TFLITE_BENCHMARK_MODEL=0 -DBUILD_TFLITE_VANILLA_RUN=0 -DBUILD_TENSORFLOW_LITE_2_3_0=0 -DBUILD_TENSORFLOW_LITE=0
+export DEBIAN_BUILD=1
+export INSTALL_PATH=debian/tmp/usr/
+%:
+ dh $@
+
+override_dh_auto_build:
+ make -f Makefile.template
+override_dh_auto_install:
+ make -f Makefile.template install
+override_dh_install:
+ install -d debian/tmp/usr/lib/pkgconfig
+ sed -i 's:@libdir@:\/usr\/lib:g' ./packaging/nnfw.pc.in
+ sed -i 's:@includedir@:\/usr\/include:g' ./packaging/nnfw.pc.in
+ sed -i 's:@version@:${DEBVER}:g' ./packaging/nnfw.pc.in
+ install -m 0644 packaging/nnfw.pc.in -T debian/tmp/usr/lib/pkgconfig/nnfw.pc
+ dh_install
--- /dev/null
+3.0 (native)
--- /dev/null
+# This is for reproducible building. Otherwise, `debuild` recognizes build artifacts as source files.
+diff-ignore="build|externals"
# Install 'add-apt-repository'
RUN apt-get update && apt-get -qqy install software-properties-common
+# Git repo for latest version (github checkout@v2 action requires v2.18)
+RUN add-apt-repository ppa:git-core/ppa -y
+
# Build tool
RUN apt-get update && apt-get -qqy install build-essential cmake scons git g++-arm-linux-gnueabihf g++-aarch64-linux-gnu
+# ARM none eabi build tool
+RUN apt-get update && apt-get -qqy install gcc-arm-none-eabi
+
+# Debian build tool
+RUN apt-get update && apt-get -qqy install fakeroot devscripts debhelper python3-all
+
# Install extra dependencies (Caffe, nnkit)
RUN apt-get update && apt-get -qqy install libboost-all-dev libgflags-dev libgoogle-glog-dev libatlas-base-dev libhdf5-dev
# Build tool
RUN apt-get update && apt-get -qqy install build-essential cmake scons git lcov g++-arm-linux-gnueabihf g++-aarch64-linux-gnu
+# Debian build tool
+RUN apt-get update && apt-get -qqy install fakeroot devscripts debhelper python3-all dh-python
+
# Install extra dependencies (Caffe, nnkit)
RUN apt-get update && apt-get -qqy install libboost-all-dev libgflags-dev libgoogle-glog-dev libatlas-base-dev libhdf5-dev
option(DOWNLOAD_PYTORCH "Download Pytorch source" ON)
option(DOWNLOAD_ONNX "Download ONNX source" ON)
option(DOWNLOAD_ABSEIL "Download Abseil-cpp source" ON)
+option(DOWNLOAD_OPENCL_HEADERS "Download OpenCl Header source" ON)
option(DOWNLOAD_PYBIND11 "Download Pybind11 source" ON)
option(DOWNLOAD_GTEST "Download Google Test source" ON)
option(GENERATE_RUNTIME_NNAPI_TESTS "Generate NNAPI operation gtest" ON)
option(ENVVAR_ONERT_CONFIG "Use environment variable for onert configuration" ON)
option(INSTALL_TEST_SCRIPTS "Install test scripts" ON)
+option(BUILD_GPU_CL "Build gpu_cl backend" ON)
#
# Default build configuration for contrib
#
#
option(DOWNLOAD_TENSORFLOW "Download Tensorflow source" ON)
option(DOWNLOAD_ABSEIL "Download Abseil source" ON)
+option(DOWNLOAD_OPENCL_HEADERS "Download Opencl_headers source" ON)
option(DOWNLOAD_EIGEN "Download Eigen source" ON)
option(DOWNLOAD_FARMHASH "Download farmhash source" ON)
option(DOWNLOAD_GEMMLOWP "Download GEMM low precision library source" ON)
option(BUILD_ARMCOMPUTE "Build ARM Compute from the downloaded source" OFF)
option(DOWNLOAD_ARMCOMPUTE "Download ARM Compute source" OFF)
option(BUILD_XNNPACK "Build XNNPACK" OFF)
+option(DOWNLOAD_OPENCL_HEADERS "Download opencl headers" OFF)
+option(BUILD_GPU_CL "Build gpu_cl backend" OFF)
option(BUILD_ARMCOMPUTE "Build ARM Compute from the downloaded source" OFF)
option(DOWNLOAD_ARMCOMPUTE "Download ARM Compute source" OFF)
option(BUILD_XNNPACK "Build XNNPACK" OFF)
+option(DOWNLOAD_OPENCL_HEADERS "Download opencl headers" OFF)
+option(BUILD_GPU_CL "Build gpu_cl backend" OFF)
option(ENVVAR_ONERT_CONFIG "Use environment variable for onert configuration" OFF)
option(BUILD_XNNPACK "Build XNNPACK" OFF)
+option(DOWNLOAD_OPENCL_HEADERS "Download opencl headers" OFF)
+option(BUILD_GPU_CL "Build gpu_cl backend" OFF)
file(GLOB TFLITE_PROFILING_TESTS "${TENSORFLOW_LITE_BASE}/profiling/*test*.cc")
list(REMOVE_ITEM TFLITE_PROFILING_SRCS ${TFLITE_PROFILING_TESTS})
-# We will use our own BuiltinOpResolver
-list(REMOVE_ITEM TFLITE_KERNEL_SRCS "${TENSORFLOW_LITE_BASE}/kernels/register.cc")
# We will use our own summarizer
list(REMOVE_ITEM TFLITE_PROFILING_SRCS "${TENSORFLOW_LITE_BASE}/profiling/profile_summarizer.cc")
list(APPEND TFLITE_SRCS ${TFLITE_CORE_SRCS})
list(APPEND TFLITE_SRCS "${FarmhashSource_DIR}/src/farmhash.cc")
+# externals for spectrogram
+list(APPEND TFLITE_SRCS "${OouraFFTSource_DIR}/fftsg.c")
+list(APPEND TFLITE_SRCS "${OouraFFTSource_DIR}/fftsg2d.c")
+
list(APPEND TFLITE_INCLUDES "${TensorFlowSource_DIR}")
list(APPEND TFLITE_INCLUDES "${AbseilSource_DIR}")
list(APPEND TFLITE_INCLUDES "${GEMMLowpSource_DIR}")
list(APPEND TFLITE_INCLUDES "${NEON2SSESource_DIR}")
endif(NEON2SSESource_FOUND)
-# This kernels are not used on nnfw
-## spectrogram
-list(REMOVE_ITEM TFLITE_SRCS "${TENSORFLOW_LITE_BASE}/kernels/audio_spectrogram.cc")
-list(REMOVE_ITEM TFLITE_SRCS "${TENSORFLOW_LITE_BASE}/kernels/audio_spectrogram_test.cc")
-list(REMOVE_ITEM TFLITE_SRCS "${TENSORFLOW_LITE_BASE}/kernels/internal/spectrogram.cc")
-
add_library(tensorflow-lite STATIC ${TFLITE_SRCS})
target_include_directories(tensorflow-lite SYSTEM PUBLIC ${TFLITE_INCLUDES})
target_compile_definitions(tensorflow-lite PUBLIC "GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK")
return_unless(GEMMLowpSource_FOUND)
nnas_find_package(TensorFlowSource EXACT 1.13.1 QUIET)
return_unless(TensorFlowSource_FOUND)
+ nnas_find_package(OouraFFTSource QUIET)
+ return_unless(OouraFFTSource_FOUND)
# Optional packages
nnas_find_package(NEON2SSESource QUIET)
fi
# The default preset
-PRESET="20210406"
+PRESET="20210706"
EXTRA_OPTIONS=()
while [ "$#" -ne 0 ]; do
REQUIRED_UNITS+=("luci")
# Tools
REQUIRED_UNITS+=("tflite2circle" "circle2circle" "tflchef" "circlechef")
+ REQUIRED_UNITS+=("circle-tensordump" "circledump")
REQUIRED_UNITS+=("tf2tfliteV2" "luci-interpreter" "circle-verify")
+ REQUIRED_UNITS+=("luci-eval-driver")
REQUIRED_UNITS+=("record-minmax" "circle-quantizer" "rawdata2hdf5")
REQUIRED_UNITS+=("circle-partitioner")
REQUIRED_UNITS+=("one-cmds")
# Tools
REQUIRED_UNITS+=("tflite2circle" "circle2circle" "tflchef" "circlechef")
REQUIRED_UNITS+=("tf2tfliteV2" "luci-interpreter" "circle-verify")
+ REQUIRED_UNITS+=("luci-eval-driver")
REQUIRED_UNITS+=("record-minmax" "circle-quantizer" "rawdata2hdf5")
REQUIRED_UNITS+=("circle-partitioner")
REQUIRED_UNITS+=("one-cmds")
--- /dev/null
+#!/bin/bash
+
+# NOTE purpose of this file is static analysis only
+# new official preset will be added when new programs are ready
+
+PRESET="20210706"
+
+function preset_configure()
+{
+ REQUIRED_UNITS=()
+ # Common Libraries
+ REQUIRED_UNITS+=("angkor" "cwrap" "pepper-str" "pepper-strcast" "pp")
+ REQUIRED_UNITS+=("oops" "pepper-assert" "pepper-csv2vec" "foder" "crew")
+ REQUIRED_UNITS+=("souschef")
+ REQUIRED_UNITS+=("safemain")
+ REQUIRED_UNITS+=("arser")
+ REQUIRED_UNITS+=("vconone")
+ # Hermes Logging Framework
+ REQUIRED_UNITS+=("hermes" "hermes-std")
+ # loco IR and related utilities
+ REQUIRED_UNITS+=("loco" "locop" "locomotiv" "logo-core" "logo")
+ # Flatbuffer I/O
+ REQUIRED_UNITS+=("mio-tflite" "mio-circle")
+ # Circle compiler library (.circle -> .circle)
+ REQUIRED_UNITS+=("luci")
+ # Tools
+ REQUIRED_UNITS+=("tflite2circle" "circle2circle" "tflchef" "circlechef")
+ REQUIRED_UNITS+=("circle-tensordump" "circledump")
+ REQUIRED_UNITS+=("tf2tfliteV2" "luci-interpreter" "circle-verify")
+ REQUIRED_UNITS+=("luci-eval-driver")
+ REQUIRED_UNITS+=("record-minmax" "circle-quantizer" "rawdata2hdf5")
+ REQUIRED_UNITS+=("circle-partitioner")
+ REQUIRED_UNITS+=("one-cmds")
+ REQUIRED_UNITS+=("bcq-tools")
+
+ NPROC=${NPROC:-$(cat /proc/cpuinfo | grep -c processor)}
+
+ # TODO Use "nncc configure" and "nncc build"
+ cmake \
+ -DCMAKE_INSTALL_PREFIX="${NNCC_INSTALL_PREFIX}" \
+ -DCMAKE_BUILD_TYPE=release \
+ -DBUILD_WHITELIST=$(join_by ";" "${REQUIRED_UNITS[@]}") \
+ -DEXTERNALS_BUILD_THREADS=$((NPROC/2)) \
+ ${EXTRA_OPTIONS[@]} \
+ "${NNAS_PROJECT_PATH}/infra/nncc"
+}
+
+function preset_install()
+{
+ install -t "${NNPKG_INSTALL_PREFIX}/bin" -D \
+ "${NNAS_PROJECT_PATH}/tools/nnpackage_tool/model2nnpkg/model2nnpkg.sh"
+
+ # Install tf2nnpkg
+ install -T -m 755 -D "${SCRIPT_PATH}/res/tf2nnpkg.${PRESET}" "${NNAS_INSTALL_PREFIX}/bin/tf2nnpkg"
+}
--- /dev/null
+#!/bin/bash
+
+function preset_configure()
+{
+ REQUIRED_UNITS=()
+ # Common Libraries
+ REQUIRED_UNITS+=("angkor" "cwrap" "pepper-str" "pepper-strcast" "pp")
+ REQUIRED_UNITS+=("oops" "pepper-assert" "pepper-csv2vec" "foder" "crew")
+ REQUIRED_UNITS+=("souschef")
+ REQUIRED_UNITS+=("safemain")
+ REQUIRED_UNITS+=("arser")
+ REQUIRED_UNITS+=("vconone")
+ # Hermes Logging Framework
+ REQUIRED_UNITS+=("hermes" "hermes-std")
+ # loco IR and related utilities
+ REQUIRED_UNITS+=("loco" "locop" "locomotiv" "logo-core" "logo")
+ # Flatbuffer I/O
+ REQUIRED_UNITS+=("mio-tflite" "mio-circle")
+ # Circle compiler library (.circle -> .circle)
+ REQUIRED_UNITS+=("luci")
+ # Tools
+ REQUIRED_UNITS+=("tflite2circle" "circle2circle" "tflchef" "circlechef")
+ REQUIRED_UNITS+=("tf2tfliteV2" "luci-interpreter" "circle-verify")
+ REQUIRED_UNITS+=("luci-eval-driver")
+ REQUIRED_UNITS+=("record-minmax" "circle-quantizer" "rawdata2hdf5")
+ REQUIRED_UNITS+=("circle-partitioner")
+ REQUIRED_UNITS+=("one-cmds")
+ REQUIRED_UNITS+=("bcq-tools")
+
+ NPROC=$(cat /proc/cpuinfo | grep -c processor)
+
+ # TODO Use "nncc configure" and "nncc build"
+ cmake \
+ -G "MSYS Makefiles" \
+ -DUSE_PROTOBUF_LEGACY_IMPORT=ON \
+ -DCMAKE_EXE_LINKER_FLAGS="-Wl,--allow-multiple-definition" \
+ -DCMAKE_SHARED_LINKER_FLAGS="-Wl,--allow-multiple-definition" \
+ -DENABLE_TEST=OFF \
+ -DDOWNLOAD_GTEST=OFF \
+ -DBUILD_GTEST=OFF \
+ -DCMAKE_C_COMPILER=gcc \
+ -DCMAKE_CXX_COMPILER=g++ \
+ -DCMAKE_INSTALL_PREFIX="${NNCC_INSTALL_PREFIX}" \
+ -DCMAKE_BUILD_TYPE=release \
+ -DBUILD_WHITELIST=$(join_by ";" "${REQUIRED_UNITS[@]}") \
+ -DEXTERNALS_BUILD_THREADS=$((NPROC/2)) \
+ ${EXTRA_OPTIONS[@]} \
+ "${NNAS_PROJECT_PATH}/infra/nncc"
+}
+
+function preset_install()
+{
+ # Install libraries to bin/ for Windows release
+ mv ${NNCC_INSTALL_PREFIX}/lib/*.dll ${NNCC_INSTALL_PREFIX}/bin
+ rm -rf ${NNCC_INSTALL_PREFIX}/lib
+
+ install -t "${NNPKG_INSTALL_PREFIX}/bin" -D \
+ "${NNAS_PROJECT_PATH}/tools/nnpackage_tool/model2nnpkg/model2nnpkg.sh"
+
+ # Install tf2nnpkg
+ install -T -m 755 -D "${SCRIPT_PATH}/res/tf2nnpkg.20210706" "${NNAS_INSTALL_PREFIX}/bin/tf2nnpkg"
+
+  # Although tensorflow is required to run 'tf2tfliteV2', it cannot be
+  # installed under mingw. Install tensorflow into a python virtual environment
+  # from a native Windows CMD (run as administrator) instead, and then copy
+  # that environment to "${NNAS_INSTALL_PREFIX}/bin/venv".
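+  #
+  # Illustrative sketch of those manual steps (paths and package versions are
+  # assumptions, not part of this script):
+  #   (Windows CMD, run as administrator)
+  #   > python -m venv venv
+  #   > venv\Scripts\activate
+  #   > pip install tensorflow
+  #   then copy the resulting "venv" directory to "${NNAS_INSTALL_PREFIX}/bin/venv"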
+}
INPUT_SHAPES=$(grep ^input ${INFO_FILE} | cut -d "[" -f2 | cut -d "]" -f1 | tr -d ' ' | xargs | tr ' ' ':')
-# Generate BCQ information metadata
-# If model has no BCQ information or invalid information, pb file is not changed.
-"${ROOT}/bin/generate_bcq_metadata" \
---input_path "${GRAPHDEF_FILE}" \
---output_path "${TMPDIR}/${MODEL_NAME}_withmeta.pb" \
---output_arrays "${OUTPUT}"
-
-# Generate BCQ information nodes as output_arrays
-# If model has no BCQ information, output_arrays would be empty.
-"${ROOT}/bin/generate_bcq_output_arrays" \
---input_path "${TMPDIR}/${MODEL_NAME}_withmeta.pb" \
---metadata_path "${TMPDIR}/${MODEL_NAME}_metadata_arrays.txt" \
---output_arrays_path "${TMPDIR}/${MODEL_NAME}_output_arrays.txt"
-
-# generate tflite file
-TF2TFLITE_CONVERT_SCRIPT="python ${ROOT}/bin/tf2tfliteV2.py ${TF_INTERFACE} "
-TF2TFLITE_CONVERT_SCRIPT+="--input_path ${TMPDIR}/${MODEL_NAME}_withmeta.pb "
-TF2TFLITE_CONVERT_SCRIPT+="--input_arrays ${INPUT} "
-TF2TFLITE_CONVERT_SCRIPT+="--output_path ${TMPDIR}/${MODEL_NAME}.tflite "
-TF2TFLITE_CONVERT_SCRIPT+="--output_arrays "
-TF2TFLITE_CONVERT_SCRIPT+="$(cat ${TMPDIR}/${MODEL_NAME}_metadata_arrays.txt)"
-TF2TFLITE_CONVERT_SCRIPT+="${OUTPUT}"
-TF2TFLITE_CONVERT_SCRIPT+="$(cat ${TMPDIR}/${MODEL_NAME}_output_arrays.txt) "
+ONE_IMPORT_BCQ_SCRIPT="${ROOT}/bin/one-import-bcq ${TF_INTERFACE} "
+ONE_IMPORT_BCQ_SCRIPT+="-i ${GRAPHDEF_FILE} "
+ONE_IMPORT_BCQ_SCRIPT+="-o ${TMPDIR}/${MODEL_NAME}.tmp.circle "
+ONE_IMPORT_BCQ_SCRIPT+="-I ${INPUT} "
+ONE_IMPORT_BCQ_SCRIPT+="-O ${OUTPUT} "
if [ ! -z ${INPUT_SHAPES} ]; then
- TF2TFLITE_CONVERT_SCRIPT+="--input_shapes ${INPUT_SHAPES} "
+ ONE_IMPORT_BCQ_SCRIPT+="-s ${INPUT_SHAPES} "
fi
-${TF2TFLITE_CONVERT_SCRIPT}
-
-# convert .tflite to .circle
-"${ROOT}/bin/tflite2circle" "${TMPDIR}/${MODEL_NAME}.tflite" "${TMPDIR}/${MODEL_NAME}.tmp.circle"
+${ONE_IMPORT_BCQ_SCRIPT}
# optimize
"${ROOT}/bin/circle2circle" --O1 "${TMPDIR}/${MODEL_NAME}.tmp.circle" "${TMPDIR}/${MODEL_NAME}.circle"
--- /dev/null
+#!/bin/bash
+
+set -e
+
+ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+
+command_exists() {
+ if [ "$#" -le 0 ]; then
+ return 1
+ fi
+ command -v "$@" > /dev/null 2>&1
+}
+
+usage()
+{
+ echo "Convert TensorFlow model to nnpackage."
+ echo "Usage: tf2nnpkg"
+ echo " --info <path/to/info>"
+ echo " --graphdef <path/to/pb>"
+ echo " -o <path/to/nnpkg/directory>"
+ echo " --v2 (optional) Use TF 2.x interface"
+ exit 255
+}
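+
+# Example invocation (hypothetical file names):
+#   tf2nnpkg --info model.info --graphdef model.pb -o ./nnpkg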
+
+TF_INTERFACE="--v1"
+
+# Parse command-line arguments
+#
+while [ "$#" -ne 0 ]; do
+ CUR="$1"
+
+ case $CUR in
+ '--help')
+ usage
+ ;;
+ '--info')
+ export INFO_FILE="$2"
+ shift 2
+ ;;
+ '--graphdef')
+ export GRAPHDEF_FILE="$2"
+ shift 2
+ ;;
+ '-o')
+ export OUTPUT_DIR="$2"
+ shift 2
+ ;;
+ '--v2')
+ TF_INTERFACE="--v2"
+ shift
+ ;;
+ *)
+ echo "${CUR}"
+ shift
+ ;;
+ esac
+done
+
+if [ -z ${GRAPHDEF_FILE} ] || [ ! -e ${GRAPHDEF_FILE} ]; then
+ echo "pb is not found. Please check --graphdef is correct."
+ exit 2
+fi
+
+if [ -z ${INFO_FILE} ] || [ ! -e ${INFO_FILE} ]; then
+ echo "info is not found. Please check --info is correct."
+ exit 2
+fi
+
+if [ -z ${OUTPUT_DIR} ]; then
+  echo "output directory is not specified. Please check -o is correct."
+  exit 2
+fi
+
+FILE_BASE=$(basename ${GRAPHDEF_FILE})
+MODEL_NAME="${FILE_BASE%.*}"
+TMPDIR=$(mktemp -d)
+trap "{ rm -rf $TMPDIR; }" EXIT
+
+# activate python virtual environment
+VIRTUALENV_LINUX="${ROOT}/bin/venv/bin/activate"
+VIRTUALENV_WINDOWS="${ROOT}/bin/venv/Scripts/activate"
+
+if [ -e ${VIRTUALENV_LINUX} ]; then
+ source ${VIRTUALENV_LINUX}
+elif [ -e ${VIRTUALENV_WINDOWS} ]; then
+ source ${VIRTUALENV_WINDOWS}
+fi
+
+# parse inputs, outputs from info file
+INPUT=$(awk -F, '/^input/ { print $2 }' ${INFO_FILE} | cut -d: -f1 | tr -d ' ' | paste -d, -s)
+OUTPUT=$(awk -F, '/^output/ { print $2 }' ${INFO_FILE} | cut -d: -f1 | tr -d ' ' | paste -d, -s)
+
+INPUT_SHAPES=$(grep ^input ${INFO_FILE} | cut -d "[" -f2 | cut -d "]" -f1 | tr -d ' ' | xargs | tr ' ' ':')
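+
+# For illustration (assumed info-file format): a line such as
+#   input, input_1:0, [1, 16, 16, 3]
+# yields INPUT="input_1" and INPUT_SHAPES="1,16,16,3"; names of multiple inputs
+# are joined with ',' and their shapes with ':'.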
+
+ONE_IMPORT_BCQ_SCRIPT="${ROOT}/bin/one-import-bcq ${TF_INTERFACE} "
+ONE_IMPORT_BCQ_SCRIPT+="-i ${GRAPHDEF_FILE} "
+ONE_IMPORT_BCQ_SCRIPT+="-o ${TMPDIR}/${MODEL_NAME}.tmp.circle "
+ONE_IMPORT_BCQ_SCRIPT+="-I ${INPUT} "
+ONE_IMPORT_BCQ_SCRIPT+="-O ${OUTPUT} "
+if [ ! -z ${INPUT_SHAPES} ]; then
+ ONE_IMPORT_BCQ_SCRIPT+="-s ${INPUT_SHAPES} "
+fi
+
+${ONE_IMPORT_BCQ_SCRIPT}
+
+# optimize
+"${ROOT}/bin/circle2circle" --O1 "${TMPDIR}/${MODEL_NAME}.tmp.circle" "${TMPDIR}/${MODEL_NAME}.circle"
+
+"${ROOT}/bin/model2nnpkg.sh" -o "${OUTPUT_DIR}" "${TMPDIR}/${MODEL_NAME}.circle"
#
# STEP 1
# Download latest TCM tool from
-# https://github.sec.samsung.net/RS-TCM/tca-standalone/releases/download/v0.0.8/tca-standalone-0.0.8.jar
+# https://github.sec.samsung.net/RS-TCM/tca-standalone/releases/download/1.0.2/tca-standalone-1.0.2.jar
#
# STEP 2
# Create symbolic link `./src` for source directory to be analyzed which has `.ahub` configuration.
#
# STEP 3
-# run this `build-tcm.sh` script.
+# run this script as `build-tcm.sh [test_target]`.
+# ex) $ build-tcm.sh               # to analyze both NN Runtime and NN Compiler
+# ex) $ build-tcm.sh NN_Runtime    # to analyze NN Runtime only
+# ex) $ build-tcm.sh NN_Compiler   # to analyze NN Compiler only
#
# See the following link for additional details.
# https://github.sec.samsung.net/RS-TCM/tca-standalone/wiki/Tutorials-CPP-Gtest
echo ${PROJECT_DIR:=${PWD}}
-java -jar $PROJECT_DIR/tca-standalone-0.0.8.jar \
+java -jar $PROJECT_DIR/tca-standalone-1.0.2.jar \
--outdir=$PROJECT_DIR/tcm-output \
--config=$PROJECT_DIR/src/.ahub/tcchecker-tca/config.yaml \
--local=$PROJECT_DIR/src \
--logfile=$PROJECT_DIR/tcm-output/tcm.log \
--debug
+ $@
CURRENT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT_PATH="$CURRENT_PATH/../../"
-# prepare pre-built armcompute library
-# android build requires pre-built armcompute library
-# if [ ! -n "$EXT_ACL_FOLDER" ]; then
-# echo "Please set EXT_ACL_FOLDER to use pre-built armcompute library"
-# exit 1
-# fi
-
-unset EXT_ACL_FOLDER
-
# prepare ndk
if [ ! -n "$NDK_DIR" ]; then
export NDK_DIR=$ROOT_PATH/tools/cross/ndk/r20/ndk
[[ "${BASH_SOURCE[0]}" == "${0}" ]] && echo "Please don't execute ${BASH_SOURCE[0]}, source it" && return
DEBUG_BUILD_ITEMS="angkor;cwrap;pepper-str;pepper-strcast;pp"
-DEBUG_BUILD_ITEMS+=";oops;pepper-assert"
+DEBUG_BUILD_ITEMS+=";oops;pepper-assert;pepper-csv2vec"
DEBUG_BUILD_ITEMS+=";hermes;hermes-std"
DEBUG_BUILD_ITEMS+=";loco;locop;locomotiv;logo-core;logo"
DEBUG_BUILD_ITEMS+=";foder;crew;souschef;arser;vconone"
[[ "${BASH_SOURCE[0]}" != "${0}" ]] && echo "Please don't source ${BASH_SOURCE[0]}, execute it" && return
+unset RELEASE_VERSION
+# TODO: needs better argument parsing
+RELEASE_VERSION="$1"
+
CURRENT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT_PATH="$CURRENT_PATH/../../"
./nncc docker-run ./nnas create-package --prefix "${PWD}/${NNCC_INSTALL_PREFIX}" -- "${CONFIG_OPTIONS}"
mkdir -p ${ARCHIVE_PATH}
-tar -zcf ${ARCHIVE_PATH}/nncc-package.tar.gz -C ${NNCC_INSTALL_PREFIX} --exclude test --exclude tflchef* ./
+tar -zcf ${ARCHIVE_PATH}/nncc-package.tar.gz -C ${NNCC_INSTALL_PREFIX} \
+ --exclude test --exclude tflchef* --exclude circle-tensordump --exclude circledump ./
tar -zcf ${ARCHIVE_PATH}/nncc-test-package.tar.gz -C ${NNCC_INSTALL_PREFIX} ./test
+if [ -z ${RELEASE_VERSION} ] || [ ${RELEASE_VERSION} == "nightly" ]; then
+ ./nncc docker-run /bin/bash -c \
+ 'dch -v $($(pwd)/tools/release_tool/onert_version.sh)~$(date "+%y%m%d%H") "nightly release" -D $(lsb_release --short --codename)'
+ ./nncc docker-run dch -r ''
+fi
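+# Note: the dch call above versions nightly builds as <onert_version>~<YYMMDDHH>,
+# e.g. "1.17.0~21070614" for a 2021-07-06 14:00 build (illustrative values).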
+
+./nncc docker-run debuild --preserve-env --no-lintian -us -uc \
+ -b --buildinfo-option=-ubuild --changes-option=-ubuild
+
popd > /dev/null
REQUIRED_UNITS=()
# Common Libraries
REQUIRED_UNITS+=("angkor" "cwrap" "pepper-str" "pepper-strcast" "pp")
+REQUIRED_UNITS+=("pepper-csv2vec")
REQUIRED_UNITS+=("oops" "safemain" "foder" "crew" "arser" "vconone")
# Hermes Logging Framework
REQUIRED_UNITS+=("hermes" "hermes-std")
Name: nnfw
Summary: nnfw
-Version: 1.15.0
+Version: 1.17.0
Release: 1
Group: Development
License: Apache-2.0 and MIT and BSD-2-Clause
Source1009: PTHREADPOOL.tar.gz
Source1010: PSIMD.tar.gz
Source1011: FP16.tar.gz
+Source1012: OPENCL_HEADERS.tar.gz
+Source1013: FARMHASH.tar.gz
+Source1014: ABSEIL.tar.gz
+Source1015: oourafft.tar.gz
Source2001: nnfw.pc.in
Source2002: nnfw-plugin.pc.in
%ifarch %{arm} aarch64
# Require python for acl-ex library build pre-process
BuildRequires: python
-BuildRequires: libarmcl-devel >= v20.05
+BuildRequires: libarmcl-devel >= v21.02
%endif
Requires(post): /sbin/ldconfig
tar -xf %{SOURCE1009} -C ./externals
tar -xf %{SOURCE1010} -C ./externals
tar -xf %{SOURCE1011} -C ./externals
+tar -xf %{SOURCE1012} -C ./externals
+tar -xf %{SOURCE1013} -C ./externals
+tar -xf %{SOURCE1014} -C ./externals
+tar -xf %{SOURCE1015} -C ./externals
%build
%ifarch arm armv7l aarch64 x86_64 %ix86
%defattr(-,root,root,-)
%ifarch arm armv7l aarch64 x86_64 %ix86
%{_libdir}/*.so
+%exclude %{_includedir}/CL/*
%endif
%files devel
--- /dev/null
+# This is a test for import/export of the STRING tensor type.
+# The interpreter or runtime may fail, as Add does not support this type.
+
+operand {
+ name: "ifm"
+ type: STRING
+ shape { }
+}
+operand {
+ name: "suffix"
+ type: STRING
+ shape { }
+ filler {
+ tag: "explicit"
+ arg: "Hello"
+ }
+}
+operand {
+ name: "ofm"
+ type: STRING
+ shape { }
+}
+operation {
+ type: "Add"
+ input: "ifm"
+ input: "suffix"
+ output: "ofm"
+ add_options {
+ activation: NONE
+ }
+}
+input: "ifm"
+output: "ofm"
--- /dev/null
+# This is a test for import/export of the STRING tensor type.
+# The interpreter or runtime may fail, as Add does not support this type.
+
+operand {
+ name: "ifm"
+ type: STRING
+ shape { }
+}
+operand {
+ name: "suffix"
+ type: STRING
+ shape { dim: 2 }
+ filler {
+ tag: "explicit"
+ arg: "Hello"
+ arg: "World"
+ }
+}
+operand {
+ name: "ofm"
+ type: STRING
+ shape { }
+}
+operation {
+ type: "Add"
+ input: "ifm"
+ input: "suffix"
+ output: "ofm"
+ add_options {
+ activation: NONE
+ }
+}
+input: "ifm"
+output: "ofm"
shape { dim: 1 dim: 9 dim: 9 dim: 1 }
}
operation {
- type: "MaxPoolWithArgMax"
+ type: "MaxPoolWithArgmax"
input: "ifm"
output: "ofm"
output: "argmax"
--- /dev/null
+# To check if MaxPoolWithArgmax is transformed into MaxPool, ArgMax and an index computation network
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "ARG_MAX_COUNT" $(op_count ARG_MAX) '=' 1
+RULE "MAX_POOL_COUNT" $(op_count MAX_POOL_2D) '=' 1
+RULE "CONV_COUNT" $(op_count CONV_2D) '=' 1
+RULE "SPLIT_COUNT" $(op_count SPLIT) '=' 1
+RULE "RESHAPE_COUNT" $(op_count RESHAPE) '=' 1
+RULE "CAST_COUNT" $(op_count CAST) '=' 2
+RULE "ADD_COUNT" $(op_count ADD) '=' 3
+RULE "MUL_COUNT" $(op_count MUL) '=' 5
+RULE "FLOOR_COUNT" $(op_count FLOOR) '=' 1
+RULE "NEG_COUNT" $(op_count NEG) '=' 1
+RULE "CONCATENATION_COUNT" $(op_count CONCATENATION) '=' 1
+RULE "PADV2_COUNT" $(op_count PADV2) '=' 1
+RULE "CUSTOM_COUNT" $(op_count 'CUSTOM(MaxPoolWithArgmax)') '=' 0
shape { dim: 1 dim: 9 dim: 9 dim: 1 }
}
operation {
- type: "MaxPoolWithArgMax"
+ type: "MaxPoolWithArgmax"
input: "ifm"
output: "ofm"
output: "argmax"
--- /dev/null
+# To check if MaxPoolWithArgmax is transformed into MaxPool, ArgMax and an index computation network
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "ARG_MAX_COUNT" $(op_count ARG_MAX) '=' 1
+RULE "MAX_POOL_COUNT" $(op_count MAX_POOL_2D) '=' 1
+RULE "CONV_COUNT" $(op_count CONV_2D) '=' 1
+RULE "SPLIT_COUNT" $(op_count SPLIT) '=' 1
+RULE "RESHAPE_COUNT" $(op_count RESHAPE) '=' 1
+RULE "CAST_COUNT" $(op_count CAST) '=' 2
+RULE "ADD_COUNT" $(op_count ADD) '=' 3
+RULE "MUL_COUNT" $(op_count MUL) '=' 5
+RULE "FLOOR_COUNT" $(op_count FLOOR) '=' 1
+RULE "NEG_COUNT" $(op_count NEG) '=' 1
+RULE "CONCATENATION_COUNT" $(op_count CONCATENATION) '=' 1
+RULE "PADV2_COUNT" $(op_count PADV2) '=' 1
+RULE "CUSTOM_COUNT" $(op_count 'CUSTOM(MaxPoolWithArgmax)') '=' 0
operand {
name: "ifm"
type: FLOAT32
- shape { dim: 1 dim: 18 dim: 18 dim: 1 }
+ shape { dim: 1 dim: 18 dim: 18 dim: 2 }
}
operand {
name: "ofm"
type: FLOAT32
- shape { dim: 1 dim: 8 dim: 8 dim: 1 }
+ shape { dim: 1 dim: 8 dim: 8 dim: 2 }
}
operand {
name: "argmax"
type: INT64
- shape { dim: 1 dim: 8 dim: 8 dim: 1 }
+ shape { dim: 1 dim: 8 dim: 8 dim: 2 }
}
operation {
- type: "MaxPoolWithArgMax"
+ type: "MaxPoolWithArgmax"
input: "ifm"
output: "ofm"
output: "argmax"
--- /dev/null
+# To check if MaxPoolWithArgmax is transformed into MaxPool, ArgMax and an index computation network
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "ARG_MAX_COUNT" $(op_count ARG_MAX) '=' 2
+RULE "MAX_POOL_COUNT" $(op_count MAX_POOL_2D) '=' 1
+RULE "CONV_COUNT" $(op_count CONV_2D) '=' 2
+RULE "SPLIT_COUNT" $(op_count SPLIT) '=' 1
+RULE "RESHAPE_COUNT" $(op_count RESHAPE) '=' 2
+RULE "CAST_COUNT" $(op_count CAST) '=' 3
+RULE "ADD_COUNT" $(op_count ADD) '=' 7
+RULE "MUL_COUNT" $(op_count MUL) '=' 8
+RULE "FLOOR_COUNT" $(op_count FLOOR) '=' 2
+RULE "NEG_COUNT" $(op_count NEG) '=' 2
+RULE "CONCATENATION_COUNT" $(op_count CONCATENATION) '=' 1
+RULE "CUSTOM_COUNT" $(op_count 'CUSTOM(MaxPoolWithArgmax)') '=' 0
--- /dev/null
+operand {
+ name: "ifm"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 3 }
+}
+operand {
+ name: "filter"
+ type: FLOAT32
+ shape { dim: 8 dim: 1 dim: 1 dim: 3 }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "1.0"
+ }
+}
+operand {
+ name: "bias"
+ type: FLOAT32
+ shape { dim: 8 }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "1.0"
+ }
+}
+operand {
+ name: "conv"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 8 }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 8 }
+}
+operation {
+ type: "Conv2D"
+ conv2d_options {
+ padding: VALID
+ stride_w: 1
+ stride_h: 1
+ }
+ input: "ifm"
+ input: "filter"
+ input: "bias"
+ output: "conv"
+}
+operation {
+ type: "FakeQuant"
+ fakequant_options {
+ min: 0.0
+ max: 1.0
+ num_bits: 8
+ narrow_range: false
+ }
+ input: "conv"
+ output: "ofm"
+}
+
+input: "ifm"
+output: "ofm"
--- /dev/null
+# To check if FakeQuant is removed by remove_fakequant
+#
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "CONV_EXIST" $(op_count CONV_2D) '=' 1
+RULE "NO_FAKE_QUANT" $(op_count FAKE_QUANT) '=' 0
--- /dev/null
+operand {
+ name: "Placeholder"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 3 }
+}
+operand {
+ name: "Const_4"
+ type: FLOAT32
+ shape { }
+ filler { tag: "explicit" arg: "6" }
+}
+operand {
+ name: "Conv2D_1"
+ type: FLOAT32
+ shape { dim: 3 dim: 3 dim: 3 dim: 3 }
+ filler { tag: "gaussian" arg: "0.0" arg: "0.1" }
+}
+operand {
+ name: "Conv2D_2"
+ type: FLOAT32
+ shape { dim: 3 }
+ filler { tag: "gaussian" arg: "0.0" arg: "0.1" }
+}
+operand {
+ name: "Conv2D_21"
+ type: FLOAT32
+ shape { dim: 3 dim: 3 dim: 3 dim: 3 }
+ filler { tag: "gaussian" arg: "0.0" arg: "0.1" }
+}
+operand {
+ name: "Conv2D_11"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 3 }
+}
+operand {
+ name: "Minimum"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 3 }
+}
+operand {
+ name: "Relu"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 3 }
+}
+operand {
+ name: "Conv2D_22"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 3 }
+}
+operand {
+ name: "Minimum_1"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 3 }
+}
+operand {
+ name: "Relu_1"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 3 }
+}
+operation {
+ type: "Conv2D"
+ input: "Placeholder"
+ input: "Conv2D_1"
+ input: "Conv2D_2"
+ output: "Conv2D_11"
+ conv2d_options {
+ padding: SAME
+ stride_w: 1
+ stride_h: 1
+ activation: NONE
+ dilation_w_factor: 1
+ dilation_h_factor: 1
+ }
+}
+operation {
+ type: "Minimum"
+ input: "Conv2D_11"
+ input: "Const_4"
+ output: "Minimum"
+}
+operation {
+ type: "ReLU"
+ input: "Minimum"
+ output: "Relu"
+}
+operation {
+ type: "Conv2D"
+ input: "Relu"
+ input: "Conv2D_21"
+ input: "Conv2D_2"
+ output: "Conv2D_22"
+ conv2d_options {
+ padding: SAME
+ stride_w: 1
+ stride_h: 1
+ activation: NONE
+ dilation_w_factor: 1
+ dilation_h_factor: 1
+ }
+}
+operation {
+ type: "Minimum"
+ input: "Conv2D_22"
+ input: "Const_4"
+ output: "Minimum_1"
+}
+operation {
+ type: "ReLU"
+ input: "Minimum_1"
+ output: "Relu_1"
+}
+input: "Placeholder"
+output: "Relu_1"
--- /dev/null
+# To check if Minimum and ReLU are converted to a Relu6 op
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "CONV_EXIST" $(op_count CONV_2D) '=' 2
+RULE "RELU6_EXIST" $(op_count RELU6) '=' 2
+RULE "MIN_NOT_EXIST" $(op_count MINIMUM) '=' 0
+RULE "RELU_NOT_EXIST" $(op_count RELU) '=' 0
--- /dev/null
+operand {
+ name: "ifm"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 3 }
+}
+operand {
+ name: "filter"
+ type: FLOAT32
+ shape { dim: 8 dim: 1 dim: 1 dim: 3 }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "1.0"
+ }
+}
+operand {
+ name: "bias"
+ type: FLOAT32
+ shape { dim: 8 }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "1.0"
+ }
+}
+operand {
+ name: "conv"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 8 }
+}
+operand {
+ name: "quantize"
+ type: UINT8
+ shape { dim: 1 dim: 16 dim: 16 dim: 8 }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 1 dim: 16 dim: 16 dim: 8 }
+}
+operation {
+ type: "Conv2D"
+ conv2d_options {
+ padding: VALID
+ stride_w: 1
+ stride_h: 1
+ }
+ input: "ifm"
+ input: "filter"
+ input: "bias"
+ output: "conv"
+}
+operation {
+ type: "Quantize"
+ input: "conv"
+ output: "quantize"
+}
+operation {
+ type: "Dequantize"
+ input: "quantize"
+ output: "ofm"
+}
+
+input: "ifm"
+output: "ofm"
--- /dev/null
+# To check if Quantize and Dequantize are removed by remove_quantdequant
+#
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "CONV_EXIST" $(op_count CONV_2D) '=' 1
+RULE "NO_QUANTIZE" $(op_count QUANTIZE) '=' 0
+RULE "NO_DEQUANTIZE" $(op_count DEQUANTIZE) '=' 0
-# To check if custom op BatchMatMulV2 is converted to circle builtin op
+# To check if this network is converted to circle InstanceNorm op
RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
-# To check if custom op InstanceNorm is converted to circle builtin op
+# To check if this network is converted to circle InstanceNorm op
RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
-# To check if custom op BatchMatMulV2 is converted to circle builtin op
+# To check if this network is converted to circle InstanceNorm op
RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
--- /dev/null
+# generated using tflchef-reverse
+# with tflite from https://github.com/Samsung/ONE/issues/7067#issuecomment-867203553
+
+operand {
+ name: "input_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/Mean"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/Mean/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "instance_normalization/Reshape"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "1"
+ arg: "1"
+ }
+}
+operand {
+ name: "instance_normalization/Reshape_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ arg: "0"
+ arg: "0"
+ }
+}
+operand {
+ name: "instance_normalization/add"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/add/y"
+ type: FLOAT32
+ shape {
+ }
+ filler {
+ tag: "explicit"
+ arg: "1e-09"
+ }
+}
+operand {
+ name: "instance_normalization/add_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/mul"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/Sqrt"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean_1/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Square"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/sub"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/sub"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/truediv"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operation {
+ type: "Mean"
+ input: "input_1"
+ input: "instance_normalization/Mean/reduction_indices"
+ output: "instance_normalization/Mean"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Mean"
+ input: "input_1"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean/reduction_indices"
+ output: "instance_normalization/reduce_std/reduce_variance/Mean"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Sub"
+ input: "input_1"
+ input: "instance_normalization/Mean"
+ output: "instance_normalization/sub"
+ sub_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Sub"
+ input: "input_1"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean"
+ output: "instance_normalization/reduce_std/reduce_variance/sub"
+ sub_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Square"
+ input: "instance_normalization/reduce_std/reduce_variance/sub"
+ output: "instance_normalization/reduce_std/reduce_variance/Square"
+}
+operation {
+ type: "Mean"
+ input: "instance_normalization/reduce_std/reduce_variance/Square"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean_1/reduction_indices"
+ output: "instance_normalization/reduce_std/reduce_variance/Mean_1"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Sqrt"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean_1"
+ output: "instance_normalization/reduce_std/Sqrt"
+}
+operation {
+ type: "Add"
+ input: "instance_normalization/reduce_std/Sqrt"
+ input: "instance_normalization/add/y"
+ output: "instance_normalization/add"
+ add_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Div"
+ input: "instance_normalization/sub"
+ input: "instance_normalization/add"
+ output: "instance_normalization/truediv"
+ div_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Mul"
+ input: "instance_normalization/truediv"
+ input: "instance_normalization/Reshape"
+ output: "instance_normalization/mul"
+ mul_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Add"
+ input: "instance_normalization/mul"
+ input: "instance_normalization/Reshape_1"
+ output: "instance_normalization/add_1"
+ add_options {
+ activation: NONE
+ }
+}
+input: "input_1"
+output: "instance_normalization/add_1"
--- /dev/null
+# To check if this network is converted to circle InstanceNorm op
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "INSTANCE_NORM_EXIST" $(op_count INSTANCE_NORM) '=' 1
+RULE "NO_ADD" $(op_count ADD) '=' 0
+RULE "NO_MUL" $(op_count MUL) '=' 0
+RULE "NO_SQRT" $(op_count SQRT) '=' 0
+RULE "NO_DIV" $(op_count DIV) '=' 0
+RULE "NO_SUB" $(op_count SUB) '=' 0
+RULE "NO_SQUARE" $(op_count SQUARE) '=' 0
+RULE "NO_MEAN" $(op_count MEAN) '=' 0
--- /dev/null
+# generated using tflchef-reverse
+# with tflite from https://github.com/Samsung/ONE/issues/7067#issuecomment-867203553
+
+operand {
+ name: "input_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/Mean"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/Mean/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "instance_normalization/add"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/add/y"
+ type: FLOAT32
+ shape {
+ }
+ filler {
+ tag: "explicit"
+ arg: "1e-09"
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/Sqrt"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean_1/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Square"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/sub"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/sub"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/truediv"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operation {
+ type: "Mean"
+ input: "input_1"
+ input: "instance_normalization/Mean/reduction_indices"
+ output: "instance_normalization/Mean"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Mean"
+ input: "input_1"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean/reduction_indices"
+ output: "instance_normalization/reduce_std/reduce_variance/Mean"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Sub"
+ input: "input_1"
+ input: "instance_normalization/Mean"
+ output: "instance_normalization/sub"
+ sub_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Sub"
+ input: "input_1"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean"
+ output: "instance_normalization/reduce_std/reduce_variance/sub"
+ sub_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Square"
+ input: "instance_normalization/reduce_std/reduce_variance/sub"
+ output: "instance_normalization/reduce_std/reduce_variance/Square"
+}
+operation {
+ type: "Mean"
+ input: "instance_normalization/reduce_std/reduce_variance/Square"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean_1/reduction_indices"
+ output: "instance_normalization/reduce_std/reduce_variance/Mean_1"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Sqrt"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean_1"
+ output: "instance_normalization/reduce_std/Sqrt"
+}
+operation {
+ type: "Add"
+ input: "instance_normalization/reduce_std/Sqrt"
+ input: "instance_normalization/add/y"
+ output: "instance_normalization/add"
+ add_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Div"
+ input: "instance_normalization/sub"
+ input: "instance_normalization/add"
+ output: "instance_normalization/truediv"
+ div_options {
+ activation: NONE
+ }
+}
+input: "input_1"
+output: "instance_normalization/truediv"
--- /dev/null
+# To check if this network is converted to circle InstanceNorm op
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "INSTANCE_NORM_EXIST" $(op_count INSTANCE_NORM) '=' 1
+RULE "NO_ADD" $(op_count ADD) '=' 0
+RULE "NO_SQRT" $(op_count SQRT) '=' 0
+RULE "NO_DIV" $(op_count DIV) '=' 0
+RULE "NO_SUB" $(op_count SUB) '=' 0
+RULE "NO_SQUARE" $(op_count SQUARE) '=' 0
+RULE "NO_MEAN" $(op_count MEAN) '=' 0
--- /dev/null
+# InstanceNorm network with one element for gamma, beta
+
+operand {
+ name: "input_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/Mean"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/Mean/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "instance_normalization/Reshape"
+ type: FLOAT32
+ shape {
+ dim: 1
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ }
+}
+operand {
+ name: "instance_normalization/Reshape_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ }
+}
+operand {
+ name: "instance_normalization/add"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/add/y"
+ type: FLOAT32
+ shape {
+ }
+ filler {
+ tag: "explicit"
+ arg: "1e-09"
+ }
+}
+operand {
+ name: "instance_normalization/add_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/mul"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/Sqrt"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Mean_1/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/Square"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/reduce_std/reduce_variance/sub"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/sub"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operand {
+ name: "instance_normalization/truediv"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 3
+ }
+}
+operation {
+ type: "Mean"
+ input: "input_1"
+ input: "instance_normalization/Mean/reduction_indices"
+ output: "instance_normalization/Mean"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Mean"
+ input: "input_1"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean/reduction_indices"
+ output: "instance_normalization/reduce_std/reduce_variance/Mean"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Sub"
+ input: "input_1"
+ input: "instance_normalization/Mean"
+ output: "instance_normalization/sub"
+ sub_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Sub"
+ input: "input_1"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean"
+ output: "instance_normalization/reduce_std/reduce_variance/sub"
+ sub_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Square"
+ input: "instance_normalization/reduce_std/reduce_variance/sub"
+ output: "instance_normalization/reduce_std/reduce_variance/Square"
+}
+operation {
+ type: "Mean"
+ input: "instance_normalization/reduce_std/reduce_variance/Square"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean_1/reduction_indices"
+ output: "instance_normalization/reduce_std/reduce_variance/Mean_1"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Sqrt"
+ input: "instance_normalization/reduce_std/reduce_variance/Mean_1"
+ output: "instance_normalization/reduce_std/Sqrt"
+}
+operation {
+ type: "Add"
+ input: "instance_normalization/reduce_std/Sqrt"
+ input: "instance_normalization/add/y"
+ output: "instance_normalization/add"
+ add_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Div"
+ input: "instance_normalization/sub"
+ input: "instance_normalization/add"
+ output: "instance_normalization/truediv"
+ div_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Mul"
+ input: "instance_normalization/truediv"
+ input: "instance_normalization/Reshape"
+ output: "instance_normalization/mul"
+ mul_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Add"
+ input: "instance_normalization/mul"
+ input: "instance_normalization/Reshape_1"
+ output: "instance_normalization/add_1"
+ add_options {
+ activation: NONE
+ }
+}
+input: "input_1"
+output: "instance_normalization/add_1"
--- /dev/null
+# To check if this network is converted to circle InstanceNorm op
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "INSTANCE_NORM_EXIST" $(op_count INSTANCE_NORM) '=' 1
+RULE "NO_ADD" $(op_count ADD) '=' 0
+RULE "NO_MUL" $(op_count MUL) '=' 0
+RULE "NO_SQRT" $(op_count SQRT) '=' 0
+RULE "NO_DIV" $(op_count DIV) '=' 0
+RULE "NO_SUB" $(op_count SUB) '=' 0
+RULE "NO_SQUARE" $(op_count SQUARE) '=' 0
+RULE "NO_MEAN" $(op_count MEAN) '=' 0
--- /dev/null
+#
+# This was generated from https://github.com/Samsung/ONE/issues/7032#issuecomment-862238083
+# with some modifications
+#
+
+operand {
+ name: "Hole"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operand {
+ name: "InstanceNorm/beta"
+ type: FLOAT32
+ shape {
+ dim: 32
+ }
+ filler {
+ tag: "constant"
+ arg: "0"
+ }
+}
+operand {
+ name: "InstanceNorm/instancenorm/add/y"
+ type: FLOAT32
+ shape {
+ }
+ filler {
+ tag: "explicit"
+ arg: "1e-06"
+ }
+}
+operand {
+ name: "InstanceNorm/moments/variance/reduction_indices"
+ type: INT32
+ shape {
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "InstanceNorm/moments/mean"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operand {
+ name: "InstanceNorm/moments/SquaredDifference"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operand {
+ name: "InstanceNorm/moments/variance"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operand {
+ name: "InstanceNorm/instancenorm/add"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operand {
+ name: "InstanceNorm/instancenorm/Rsqrt"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operand {
+ name: "InstanceNorm/instancenorm/mul_1"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operand {
+ name: "InstanceNorm/instancenorm/mul_2"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operand {
+ name: "InstanceNorm/instancenorm/sub"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operand {
+ name: "InstanceNorm/instancenorm/add_1"
+ type: FLOAT32
+ shape {
+ dim: 1 dim: 1 dim: 1 dim: 32
+ }
+}
+operation {
+ type: "Mean"
+ input: "Hole"
+ input: "InstanceNorm/moments/variance/reduction_indices"
+ output: "InstanceNorm/moments/mean"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "SquaredDifference"
+ input: "Hole"
+ input: "InstanceNorm/moments/mean"
+ output: "InstanceNorm/moments/SquaredDifference"
+}
+operation {
+ type: "Mean"
+ input: "InstanceNorm/moments/SquaredDifference"
+ input: "InstanceNorm/moments/variance/reduction_indices"
+ output: "InstanceNorm/moments/variance"
+ mean_options {
+ keep_dims: true
+ }
+}
+operation {
+ type: "Add"
+ input: "InstanceNorm/moments/variance"
+ input: "InstanceNorm/instancenorm/add/y"
+ output: "InstanceNorm/instancenorm/add"
+ add_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Rsqrt"
+ input: "InstanceNorm/instancenorm/add"
+ output: "InstanceNorm/instancenorm/Rsqrt"
+}
+operation {
+ type: "Mul"
+ input: "Hole"
+ input: "InstanceNorm/instancenorm/Rsqrt"
+ output: "InstanceNorm/instancenorm/mul_1"
+ mul_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Mul"
+ input: "InstanceNorm/moments/mean"
+ input: "InstanceNorm/instancenorm/Rsqrt"
+ output: "InstanceNorm/instancenorm/mul_2"
+ mul_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Sub"
+ input: "InstanceNorm/beta"
+ input: "InstanceNorm/instancenorm/mul_2"
+ output: "InstanceNorm/instancenorm/sub"
+ sub_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Add"
+ input: "InstanceNorm/instancenorm/mul_1"
+ input: "InstanceNorm/instancenorm/sub"
+ output: "InstanceNorm/instancenorm/add_1"
+ add_options {
+ activation: NONE
+ }
+}
+input: "Hole"
+output: "InstanceNorm/instancenorm/add_1"
--- /dev/null
+# To check if this network is converted to circle InstanceNorm op
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "INSTANCE_NORM_EXIST" $(op_count INSTANCE_NORM) '=' 1
+RULE "NO_ADD" $(op_count ADD) '=' 0
+RULE "NO_MUL" $(op_count MUL) '=' 0
+RULE "NO_POW" $(op_count POW) '=' 0
+RULE "NO_DIV" $(op_count DIV) '=' 0
+RULE "NO_SQUARED_DIFF" $(op_count SQUARED_DIFFERENCE) '=' 0
+RULE "NO_MEAN" $(op_count MEAN) '=' 0
+RULE "NO_RSQRT" $(op_count RSQRT) '=' 0
+RULE "NO_SUB" $(op_count SUB) '=' 0
--- /dev/null
+operand {
+ name: "ifm"
+ type: FLOAT32
+ shape { dim: 3 dim: 8 dim: 8 dim: 4 }
+}
+operand {
+ name: "inner"
+ type: FLOAT32
+ shape { dim: 3 dim: 4 }
+}
+operand {
+ name: "reduction_indices1"
+ type: INT32
+ shape { dim: 2 }
+ filler { tag: "explicit" arg: "1" arg: "2" }
+}
+operand {
+ name: "reduction_indices2"
+ type: INT32
+ shape { dim: 1 }
+ filler { tag: "explicit" arg: "1"}
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 3 }
+}
+operation {
+ type: "Mean"
+ mean_options {
+ keep_dims: false
+ }
+ input: "ifm"
+ input: "reduction_indices1"
+ output: "inner"
+}
+operation {
+ type: "Mean"
+ mean_options {
+ keep_dims: false
+ }
+ input: "inner"
+ input: "reduction_indices2"
+ output: "ofm"
+}
+input: "ifm"
+output: "ofm"
--- /dev/null
+# To check if the two consecutive Mean operations are fused into a single Mean
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "MEAN_SINGLE" $(op_count MEAN) '=' 1
--- /dev/null
+operand {
+ name: "ifm"
+ type: FLOAT32
+ shape { dim: 3 dim: 8 dim: 8 dim: 4 }
+}
+operand {
+ name: "inner"
+ type: FLOAT32
+ shape { dim: 3 dim: 8 dim: 1 dim: 4 }
+}
+operand {
+ name: "reduction_indices1"
+ type: INT32
+ shape { dim: 1 }
+ filler { tag: "explicit" arg: "2" }
+}
+operand {
+ name: "reduction_indices2"
+ type: INT32
+ shape { dim: 1 }
+ filler { tag: "explicit" arg: "1"}
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 3 dim: 1 dim: 1 dim: 4}
+}
+operation {
+ type: "Mean"
+ mean_options {
+ keep_dims: true
+ }
+ input: "ifm"
+ input: "reduction_indices1"
+ output: "inner"
+}
+operation {
+ type: "Mean"
+ mean_options {
+ keep_dims: true
+ }
+ input: "inner"
+ input: "reduction_indices2"
+ output: "ofm"
+}
+input: "ifm"
+output: "ofm"
--- /dev/null
+# To check if the two consecutive Mean operations are fused into a single Mean
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "MEAN_SINGLE" $(op_count MEAN) '=' 1
--- /dev/null
+operand {
+ name: "ifm"
+ type: FLOAT32
+ shape { dim: 3 dim: 8 dim: 8 dim: 4 }
+}
+operand {
+ name: "inner1"
+ type: FLOAT32
+ shape { dim: 3 dim: 8 dim: 1 dim: 4 }
+}
+operand {
+ name: "inner2"
+ type: FLOAT32
+ shape { dim: 3 dim: 1 dim: 4 dim: 8 }
+}
+operand {
+ name: "reduction_indices1"
+ type: INT32
+ shape { dim: 1 }
+ filler { tag: "explicit" arg: "2" }
+}
+operand {
+ name: "reduction_indices2"
+ type: INT32
+ shape { dim: 1 }
+ filler { tag: "explicit" arg: "3"}
+}
+operand {
+ name: "perm"
+ type: INT32
+ shape { dim: 4 }
+ filler { tag: "explicit" arg: "0" arg: "2" arg: "3" arg: "1" }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 3 dim: 1 dim: 4 dim: 1 }
+}
+operation {
+ type: "Mean"
+ mean_options {
+ keep_dims: true
+ }
+ input: "ifm"
+ input: "reduction_indices1"
+ output: "inner1"
+}
+operation {
+ type: "Transpose"
+ transpose_options {
+ }
+ input: "inner1"
+ input: "perm"
+ output: "inner2"
+}
+operation {
+ type: "Mean"
+ mean_options {
+ keep_dims: true
+ }
+ input: "inner2"
+ input: "reduction_indices2"
+ output: "ofm"
+}
+input: "ifm"
+output: "ofm"
--- /dev/null
+# To check if the Mean operations are fused into a single Mean
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "MEAN_SINGLE" $(op_count MEAN) '=' 1
--- /dev/null
+operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operand {
+ name: "ifm3"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operand {
+ name: "ifm4"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operand {
+ name: "some/node/add1;and/another"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operand {
+ name: "some/node/add2;and/another"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operation {
+ type: "Add"
+ add_options {
+ activation: NONE
+ }
+ input: "ifm1"
+ input: "ifm2"
+ output: "some/node/add1;and/another"
+}
+operation {
+ type: "Add"
+ add_options {
+ activation: NONE
+ }
+ input: "some/node/add1;and/another"
+ input: "ifm3"
+ output: "some/node/add2;and/another"
+}
+operation {
+ type: "Sub"
+ sub_options {
+ activation: NONE
+ }
+ input: "some/node/add2;and/another"
+ input: "ifm4"
+ output: "ofm"
+}
+input: "ifm1"
+input: "ifm2"
+input: "ifm3"
+input: "ifm4"
+output: "ofm"
--- /dev/null
+operand {
+ name: "ifm"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operand {
+ name: "const"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "1.0"
+ }
+}
+operand {
+ name: "add1"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operand {
+ name: "add2"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 1 dim: 3 dim: 3 dim: 2 }
+}
+operation {
+ type: "Add"
+ add_options {
+ activation: NONE
+ }
+ input: "ifm"
+ input: "const"
+ output: "add1"
+}
+operation {
+ type: "Add"
+ add_options {
+ activation: NONE
+ }
+ input: "add1"
+ input: "const"
+ output: "add2"
+}
+operation {
+ type: "Sub"
+ sub_options {
+ activation: NONE
+ }
+ input: "add2"
+ input: "const"
+ output: "ofm"
+}
+input: "ifm"
+output: "ofm"
--- /dev/null
+version: 1
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operation {
+ type: "Add"
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ add_options {
+ activation: NONE
+ }
+ }
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ name: "IF_ELSE"
+}
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operation {
+ type: "Mul"
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ mul_options {
+ activation: NONE
+ }
+ }
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ name: "IF_THEN"
+}
+
+operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operand {
+ name: "cond"
+ type: BOOL
+ shape { }
+}
+operand {
+ name: "add"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operand {
+ name: "sub"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operation {
+ type: "Add"
+ input: "ifm1"
+ input: "ifm2"
+ output: "add"
+ add_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Sub"
+ input: "ifm1"
+ input: "ifm2"
+ output: "sub"
+ sub_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "If"
+ input: "cond"
+ input: "add"
+ input: "sub"
+ output: "ofm"
+ if_options {
+ then_subgraph_index: 2
+ else_subgraph_index: 1
+ }
+}
+input: "cond"
+input: "ifm1"
+input: "ifm2"
+output: "ofm"
+name: "Main"
--- /dev/null
+version: 1
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operation {
+ type: "Add"
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ add_options {
+ activation: NONE
+ }
+ }
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ name: "IF_THEN_THEN"
+}
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operation {
+ type: "Mul"
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ mul_options {
+ activation: NONE
+ }
+ }
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ name: "IF_THEN_ELSE"
+}
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operation {
+ type: "Add"
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ add_options {
+ activation: NONE
+ }
+ }
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ name: "IF_ELSE"
+}
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operand {
+ name: "cond"
+ type: BOOL
+ shape { dim: 1 }
+ filler {
+ tag: "explicit"
+ arg: "T"
+ }
+ }
+ operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+ }
+ operation {
+ type: "If"
+ input: "cond"
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ if_options {
+ then_subgraph_index: 1
+ else_subgraph_index: 2
+ }
+ }
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ name: "IF_THEN"
+}
+
+operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operand {
+ name: "cond"
+ type: BOOL
+ shape { dim: 1 }
+ filler {
+ tag: "explicit"
+ arg: "T"
+ }
+}
+operand {
+ name: "add"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operand {
+ name: "sub"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 2 dim: 3 }
+}
+operation {
+ type: "Add"
+ input: "ifm1"
+ input: "ifm2"
+ output: "add"
+ add_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Sub"
+ input: "ifm1"
+ input: "ifm2"
+ output: "sub"
+ sub_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "If"
+ input: "cond"
+ input: "add"
+ input: "sub"
+ output: "ofm"
+ if_options {
+ then_subgraph_index: 4
+ else_subgraph_index: 3
+ }
+}
+input: "ifm1"
+input: "ifm2"
+output: "ofm"
+name: "Main"
--- /dev/null
+test.readme of Part_While_000
+
+MAXIMUM and MINIMUM ops exist to clamp the random input to 0
+(Minimum(x, 0) followed by Maximum(., 0) always yields 0),
+so that this model loops from 0 to 10.
--- /dev/null
+version: 1
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim:1 }
+ }
+ operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { dim:1 }
+ filler {
+ tag: "explicit"
+ arg: "10"
+ }
+ }
+ operand {
+ name: "ofm"
+ type: BOOL
+ shape { dim:1 }
+ }
+ operation {
+ type: "Less"
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ }
+ input: "ifm1"
+ output: "ofm"
+ name: "WHILE_COND"
+}
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim:1 }
+ }
+ operand {
+ name: "ifm3"
+ type: FLOAT32
+ shape { dim:1 }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ }
+ }
+ operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim:1 }
+ }
+ operation {
+ type: "Add"
+ input: "ifm1"
+ input: "ifm3"
+ output: "ofm"
+ add_options {
+ activation: NONE
+ }
+ }
+ input: "ifm1"
+ output: "ofm"
+ name: "WHILE_BODY"
+}
+
+operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { dim:1 }
+}
+operand {
+ name: "zero"
+ type: FLOAT32
+ shape { dim:1 }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ }
+}
+operand {
+ name: "min"
+ type: FLOAT32
+ shape { dim:1 }
+}
+operand {
+ name: "max"
+ type: FLOAT32
+ shape { dim:1 }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim:1 }
+}
+operation {
+ type: "Minimum"
+ maximum_options {
+ }
+ input: "ifm1"
+ input: "zero"
+ output: "min"
+}
+operation {
+ type: "Maximum"
+ maximum_options {
+ }
+ input: "min"
+ input: "zero"
+ output: "max"
+}
+operation {
+ type: "While"
+ input: "max"
+ output: "ofm"
+ while_options {
+ body_subgraph_index: 2
+ cond_subgraph_index: 1
+ }
+}
+input: "ifm1"
+output: "ofm"
+name: "Main"
--- /dev/null
+# To check that the circle model contains a While op
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "WHILE_EXIST" $(op_count WHILE) '=' 1
--- /dev/null
+test.readme of Part_While_001
+
+This has a WHILE Op inside the WHILE_BODY subgraph.
+MAXIMUM and MINIMUM ops exist to clamp the random input to 0
+so that this model loops from 0 to 10.
--- /dev/null
+version: 1
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { }
+ }
+ operand {
+ name: "ifm2"
+ type: FLOAT32
+ shape { }
+ filler {
+ tag: "explicit"
+ arg: "10"
+ }
+ }
+ operand {
+ name: "ofm"
+ type: BOOL
+ shape { }
+ }
+ operation {
+ type: "Less"
+ input: "ifm1"
+ input: "ifm2"
+ output: "ofm"
+ }
+ input: "ifm1"
+ output: "ofm"
+ name: "WHILE_WHILE_COND"
+}
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { }
+ }
+ operand {
+ name: "ifm3"
+ type: FLOAT32
+ shape { }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ }
+ }
+ operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { }
+ }
+ operation {
+ type: "Add"
+ input: "ifm1"
+ input: "ifm3"
+ output: "ofm"
+ add_options {
+ activation: NONE
+ }
+ }
+ input: "ifm1"
+ output: "ofm"
+ name: "WHILE_WHILE_BODY"
+}
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { }
+ }
+ operand {
+ name: "ifm3"
+ type: FLOAT32
+ shape { }
+ filler {
+ tag: "explicit"
+ arg: "10"
+ }
+ }
+ operand {
+ name: "ofm"
+ type: BOOL
+ shape { }
+ }
+ operation {
+ type: "Less"
+ input: "ifm1"
+ input: "ifm3"
+ output: "ofm"
+ }
+ input: "ifm1"
+ output: "ofm"
+ name: "WHILE_COND"
+}
+
+graph {
+ operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { }
+ }
+ operand {
+ name: "ifm3"
+ type: FLOAT32
+ shape { }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ }
+ }
+ operand {
+ name: "add"
+ type: FLOAT32
+ shape { }
+ }
+ operand {
+ name: "ofm1"
+ type: FLOAT32
+ shape { }
+ }
+ operation {
+ type: "Add"
+ input: "ifm1"
+ input: "ifm3"
+ output: "add"
+ add_options {
+ activation: NONE
+ }
+ }
+ operation {
+ type: "While"
+ input: "add"
+ output: "ofm1"
+ while_options {
+ cond_subgraph_index: 1
+ body_subgraph_index: 2
+ }
+ }
+ input: "ifm1"
+ output: "ofm1"
+ name: "WHILE_BODY"
+}
+
+operand {
+ name: "ifm1"
+ type: FLOAT32
+ shape { }
+}
+operand {
+ name: "zero"
+ type: FLOAT32
+ shape { }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ }
+}
+operand {
+ name: "min"
+ type: FLOAT32
+ shape { }
+}
+operand {
+ name: "max"
+ type: FLOAT32
+ shape { }
+}
+operand {
+ name: "ofm1"
+ type: FLOAT32
+ shape { }
+}
+operation {
+ type: "Minimum"
+ maximum_options {
+ }
+ input: "ifm1"
+ input: "zero"
+ output: "min"
+}
+operation {
+ type: "Maximum"
+ maximum_options {
+ }
+ input: "min"
+ input: "zero"
+ output: "max"
+}
+operation {
+ type: "While"
+ input: "max"
+ output: "ofm1"
+ while_options {
+ cond_subgraph_index: 3
+ body_subgraph_index: 4
+ }
+}
+input: "ifm1"
+output: "ofm1"
+name: "Main"
--- /dev/null
+operand {
+ name: "ifm"
+ type: FLOAT32
+ shape { dim: 4 }
+}
+operand {
+ name: "ofm"
+ type: UINT8
+ shape { dim: 4 }
+ quant { min: 0 max: 255 scale: 1.0 zero_point: 0 }
+}
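+# Note (illustrative): affine quantization maps real = scale * (q - zero_point),
+# so with scale 1.0 and zero_point 0 the uint8 value q represents the float q itself.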
+operation {
+ type: "Quantize"
+ input: "ifm"
+ output: "ofm"
+}
+input: "ifm"
+output: "ofm"
--- /dev/null
+operand {
+ name: "Const_8"
+ type: INT32
+ shape {
+ dim: 4
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ arg: "0"
+ arg: "0"
+ arg: "0"
+ arg: "1"
+ arg: "1"
+ arg: "1"
+ arg: "1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_8/perm"
+ type: INT32
+ shape {
+ dim: 4
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ arg: "3"
+ arg: "1"
+ arg: "2"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_9/perm"
+ type: INT32
+ shape {
+ dim: 4
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ arg: "2"
+ arg: "3"
+ arg: "1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "batchnorm/mul"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 1
+ }
+ filler {
+ tag: "explicit"
+ arg: "0.00498116"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "batchnorm/sub"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 1
+ dim: 1
+ }
+ filler {
+ tag: "explicit"
+ arg: "0.0332279"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "convolution_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 3
+ dim: 3
+ dim: 16
+ }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "0.1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "convolution_2"
+ type: FLOAT32
+ shape {
+ dim: 8
+ dim: 3
+ dim: 3
+ dim: 1
+ }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "0.1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "convolution_1_weight"
+ type: FLOAT32
+ shape {
+ dim: 1
+ }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "0.1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "convolution_2_weight"
+ type: FLOAT32
+ shape {
+ dim: 8
+ }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "0.1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_5;PartitionedCall/transpose_5"
+ type: INT32
+ shape {
+ dim: 4
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "1"
+ arg: "128"
+ arg: "128"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_7"
+ type: INT32
+ shape {
+ dim: 4
+ }
+ filler {
+ tag: "explicit"
+ arg: "1"
+ arg: "130"
+ arg: "130"
+ arg: "1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "Pad_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 130
+ dim: 130
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_4"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 130
+ dim: 130
+ dim: 16
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "convolution_1_out"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 128
+ dim: 128
+ dim: 1
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_51"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 128
+ dim: 128
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "batchnorm/mul_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 128
+ dim: 128
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "batchnorm/add_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 128
+ dim: 128
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "Pad_2"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 1
+ dim: 130
+ dim: 130
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_71"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 130
+ dim: 130
+ dim: 1
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "convolution_2_out"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 128
+ dim: 128
+ dim: 8
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_8"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 8
+ dim: 128
+ dim: 128
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operation {
+ type: "Transpose"
+ input: "Pad_1"
+ input: "transpose_9/perm"
+ output: "transpose_4"
+}
+operation {
+ type: "Conv2D"
+ input: "transpose_4"
+ input: "convolution_1"
+ input: "convolution_1_weight"
+ output: "convolution_1_out"
+ conv2d_options {
+ padding: VALID
+ stride_w: 1
+ stride_h: 1
+ activation: NONE
+ dilation_w_factor: 1
+ dilation_h_factor: 1
+ }
+}
+operation {
+ type: "Reshape"
+ input: "convolution_1_out"
+ input: "transpose_5;PartitionedCall/transpose_5"
+ output: "transpose_51"
+}
+operation {
+ type: "Mul"
+ input: "transpose_51"
+ input: "batchnorm/mul"
+ output: "batchnorm/mul_1"
+ mul_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Add"
+ input: "batchnorm/mul_1"
+ input: "batchnorm/sub"
+ output: "batchnorm/add_1"
+ add_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Pad"
+ input: "batchnorm/add_1"
+ input: "Const_8"
+ output: "Pad_2"
+}
+operation {
+ type: "Reshape"
+ input: "Pad_2"
+ input: "transpose_7"
+ output: "transpose_71"
+}
+operation {
+ type: "Conv2D"
+ input: "transpose_71"
+ input: "convolution_2"
+ input: "convolution_2_weight"
+ output: "convolution_2_out"
+ conv2d_options {
+ padding: VALID
+ stride_w: 1
+ stride_h: 1
+ activation: NONE
+ dilation_w_factor: 1
+ dilation_h_factor: 1
+ }
+}
+operation {
+ type: "Transpose"
+ input: "convolution_2_out"
+ input: "transpose_8/perm"
+ output: "transpose_8"
+}
+input: "Pad_1"
+output: "transpose_8"
--- /dev/null
+# To check ONNX conversion is OK
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "CONV_EXIST" $(op_count CONV_2D) '=' 2
+RULE "NO_MUL" $(op_count MUL) '=' 0
+RULE "NO_ADD" $(op_count ADD) '=' 0
+RULE "NO_RESHAPE" $(op_count RESHAPE) '=' 0
+RULE "NO_TRANSPOSE" $(op_count TRANSPOSE) '=' 0
--- /dev/null
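+# Relu_23 (NCHW) -> Transpose -> Conv2D -> Transpose -> Mul -> Add(RELU) -> Mean -> Mean
+# The companion rule file expects one Conv2D and one Mean, with Transpose/Mul/Add removed.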
+operand {
+ name: "Mean_4/reduction_indices"
+ type: INT32
+ shape {
+ }
+ filler {
+ tag: "explicit"
+ arg: "3"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "Mean_5/reduction_indices"
+ type: INT32
+ shape {
+ }
+ filler {
+ tag: "explicit"
+ arg: "2"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_73/perm"
+ type: INT32
+ shape {
+ dim: 4
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ arg: "2"
+ arg: "3"
+ arg: "1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_8/perm"
+ type: INT32
+ shape {
+ dim: 4
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ arg: "3"
+ arg: "1"
+ arg: "2"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "batchnorm_24/mul"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 256
+ dim: 1
+ dim: 1
+ }
+ filler {
+ tag: "explicit"
+ arg: "1.0"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "batchnorm_24/sub"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 256
+ dim: 1
+ dim: 1
+ }
+ filler {
+ tag: "explicit"
+ arg: "0.0"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "convolution_12"
+ type: FLOAT32
+ shape {
+ dim: 256
+ }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "0.1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "convolution_121"
+ type: FLOAT32
+ shape {
+ dim: 256
+ dim: 1
+ dim: 1
+ dim: 256
+ }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "0.1"
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "Relu_23"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 256
+ dim: 5
+ dim: 5
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_73"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 5
+ dim: 5
+ dim: 256
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "convolution_122"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 5
+ dim: 5
+ dim: 256
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "transpose_74"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 256
+ dim: 5
+ dim: 5
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "batchnorm_24/mul_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 256
+ dim: 5
+ dim: 5
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "Relu_24"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 256
+ dim: 5
+ dim: 5
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "Mean_4"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 256
+ dim: 5
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operand {
+ name: "Mean_5"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 256
+ }
+ quant {
+ quantized_dimension: 0
+ }
+ is_variable: false
+}
+operation {
+ type: "Transpose"
+ input: "Relu_23"
+ input: "transpose_73/perm"
+ output: "transpose_73"
+}
+operation {
+ type: "Conv2D"
+ input: "transpose_73"
+ input: "convolution_121"
+ input: "convolution_12"
+ output: "convolution_122"
+ conv2d_options {
+ padding: VALID
+ stride_w: 1
+ stride_h: 1
+ activation: NONE
+ dilation_w_factor: 1
+ dilation_h_factor: 1
+ }
+}
+operation {
+ type: "Transpose"
+ input: "convolution_122"
+ input: "transpose_8/perm"
+ output: "transpose_74"
+}
+operation {
+ type: "Mul"
+ input: "transpose_74"
+ input: "batchnorm_24/mul"
+ output: "batchnorm_24/mul_1"
+ mul_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Add"
+ input: "batchnorm_24/mul_1"
+ input: "batchnorm_24/sub"
+ output: "Relu_24"
+ add_options {
+ activation: RELU
+ }
+}
+operation {
+ type: "Mean"
+ input: "Relu_24"
+ input: "Mean_4/reduction_indices"
+ output: "Mean_4"
+ mean_options {
+ keep_dims: false
+ }
+}
+operation {
+ type: "Mean"
+ input: "Mean_4"
+ input: "Mean_5/reduction_indices"
+ output: "Mean_5"
+ mean_options {
+ keep_dims: false
+ }
+}
+input: "Relu_23"
+output: "Mean_5"
--- /dev/null
+# To check ONNX conversion is OK
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "CONV_EXIST" $(op_count CONV_2D) '=' 1
+RULE "ONE_MEAN" $(op_count MEAN) '=' 1
+RULE "NO_TRANSPOSE" $(op_count TRANSPOSE) '=' 0
+RULE "NO_ADD" $(op_count ADD) '=' 0
+RULE "NO_MUL" $(op_count MUL) '=' 0
--- /dev/null
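+# input0 -> Pad -> Transpose -> Conv2D -> Transpose -> Mul -> Add -> Minimum(6) -> Maximum(0)
+# The companion rule file expects a single Conv2D with no Transpose/Mul/Add/Minimum/Maximum/Relu6 left.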
+operand {
+ name: "input0"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 3
+ dim: 32
+ dim: 32
+ }
+}
+operand {
+ name: "Const_95"
+ type: INT32
+ shape {
+ dim: 4
+ dim: 2
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ arg: "0"
+ arg: "0"
+ arg: "0"
+ arg: "1"
+ arg: "1"
+ arg: "1"
+ arg: "1"
+ }
+}
+operand {
+ name: "Pad"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 3
+ dim: 34
+ dim: 34
+ }
+}
+operand {
+ name: "transpose_158/perm"
+ type: INT32
+ shape {
+ dim: 4
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ arg: "2"
+ arg: "3"
+ arg: "1"
+ }
+}
+operand {
+ name: "transpose_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 34
+ dim: 34
+ dim: 3
+ }
+}
+operand {
+ name: "convolution"
+ type: FLOAT32
+ shape {
+ dim: 16
+ dim: 3
+ dim: 3
+ dim: 3
+ }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "0.1"
+ }
+}
+operand {
+ name: "convolution_41"
+ type: FLOAT32
+ shape {
+ dim: 16
+ }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "0.1"
+ }
+}
+operand {
+ name: "convolution1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 16
+ }
+}
+operand {
+ name: "transpose_159/perm"
+ type: INT32
+ shape {
+ dim: 4
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ arg: "3"
+ arg: "1"
+ arg: "2"
+ }
+}
+operand {
+ name: "transpose_2"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 16
+ }
+}
+operand {
+ name: "batchnorm/mul"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 1
+ dim: 1
+ }
+ filler {
+ tag: "explicit"
+ arg: "0.001"
+ }
+}
+operand {
+ name: "batchnorm/mul_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 16
+ }
+}
+operand {
+ name: "batchnorm/sub"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 1
+ dim: 1
+ }
+ filler {
+ tag: "explicit"
+ arg: "0.0"
+ }
+}
+operand {
+ name: "batchnorm/add_1"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 16
+ }
+}
+operand {
+ name: "clip_by_value_9/Minimum/y"
+ type: FLOAT32
+ shape {
+ }
+ filler {
+ tag: "explicit"
+ arg: "6"
+ }
+}
+operand {
+ name: "clip_by_value/Minimum"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 16
+ }
+}
+operand {
+ name: "clip_by_value_9/y"
+ type: FLOAT32
+ shape {
+ }
+ filler {
+ tag: "explicit"
+ arg: "0"
+ }
+}
+operand {
+ name: "clip_by_value"
+ type: FLOAT32
+ shape {
+ dim: 1
+ dim: 16
+ dim: 16
+ dim: 16
+ }
+}
+operation {
+ type: "Pad"
+ input: "input0"
+ input: "Const_95"
+ output: "Pad"
+}
+operation {
+ type: "Transpose"
+ input: "Pad"
+ input: "transpose_158/perm"
+ output: "transpose_1"
+}
+operation {
+ type: "Conv2D"
+ input: "transpose_1"
+ input: "convolution"
+ input: "convolution_41"
+ output: "convolution1"
+ conv2d_options {
+ padding: VALID
+ stride_w: 2
+ stride_h: 2
+ activation: NONE
+ dilation_w_factor: 1
+ dilation_h_factor: 1
+ }
+}
+operation {
+ type: "Transpose"
+ input: "convolution1"
+ input: "transpose_159/perm"
+ output: "transpose_2"
+}
+operation {
+ type: "Mul"
+ input: "transpose_2"
+ input: "batchnorm/mul"
+ output: "batchnorm/mul_1"
+ mul_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Add"
+ input: "batchnorm/mul_1"
+ input: "batchnorm/sub"
+ output: "batchnorm/add_1"
+ add_options {
+ activation: NONE
+ }
+}
+operation {
+ type: "Minimum"
+ input: "batchnorm/add_1"
+ input: "clip_by_value_9/Minimum/y"
+ output: "clip_by_value/Minimum"
+}
+operation {
+ type: "Maximum"
+ input: "clip_by_value/Minimum"
+ input: "clip_by_value_9/y"
+ output: "clip_by_value"
+}
+input: "input0"
+output: "clip_by_value"
--- /dev/null
+# To check ONNX conversion is OK
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "CONV_EXIST" $(op_count CONV_2D) '=' 1
+RULE "NO_TRANSPOSE" $(op_count TRANSPOSE) '=' 0
+RULE "NO_ADD" $(op_count ADD) '=' 0
+RULE "NO_MUL" $(op_count MUL) '=' 0
+RULE "NO_RELU6" $(op_count RELU6) '=' 0
+RULE "NO_MINIMUM" $(op_count MINIMUM) '=' 0
+RULE "NO_MAXIMUM" $(op_count MAXIMUM) '=' 0
--- /dev/null
+# Recipe for StridedSlice that will be converted to Reshape by SubstituteStridedSliceToReshapePass
+#
+# shrink_axis_mask will remove axis 0
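+# e.g. ifm shape [1, 10, 1, 4] becomes ofm shape [10, 1, 4]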
+
+operand {
+ name: "ifm"
+ type: FLOAT32
+ shape { dim: 1 dim: 10 dim: 1 dim: 4 }
+}
+operand {
+ name: "begin"
+ type: INT32
+ shape { dim: 4 }
+ filler {
+ tag: "explicit"
+ arg: "0" arg: "0" arg: "0" arg: "0"
+ }
+}
+operand {
+ name: "end"
+ type: INT32
+ shape { dim: 4 }
+ filler {
+ tag: "explicit"
+ arg: "1" arg: "10" arg: "1" arg: "100"
+ }
+}
+operand {
+ name: "strides"
+ type: INT32
+ shape { dim: 4 }
+ filler {
+ tag: "explicit"
+ arg: "1" arg: "1" arg: "1" arg: "1"
+ }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+  shape { dim: 10 dim: 1 dim: 4 }
+}
+operation {
+ type: "StridedSlice"
+ strided_slice_options {
+ begin_mask: 0
+ end_mask: 0
+ ellipsis_mask: 0
+ new_axis_mask: 0
+ shrink_axis_mask: 1
+ }
+ input: "ifm"
+ input: "begin"
+ input: "end"
+ input: "strides"
+ output: "ofm"
+}
+input: "ifm"
+output: "ofm"
--- /dev/null
+# To check if StridedSlice is substituted to Reshape
+
+RULE "VERIFY_FILE_FORMAT" $(verify_file_format) '=' 1
+
+RULE "RESHAPE_EXIST" $(op_count RESHAPE) '=' 1
+RULE "NO_STRIDEDSLICE" $(op_count STRIDEDSLICE) '=' 0
--- /dev/null
+import tensorflow as tf
+
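+# Nested conditional graph: the outer tf.cond ("Cond") runs fn1, which evaluates an inner
+# tf.cond ("Cond0") choosing between Multiply ("Hole0M") and Add ("Hole0A"), or falls back to fn2 ("HoleA").
+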
+x_ = tf.compat.v1.placeholder(dtype=tf.float32, shape=[], name="HoleX")
+y_ = tf.compat.v1.placeholder(dtype=tf.float32, shape=[], name="HoleY")
+z_ = tf.compat.v1.placeholder(dtype=tf.float32, shape=[], name="HoleZ")
+
+
+def fn01(a, b):
+ return tf.math.multiply(a, b, name="Hole0M")
+
+
+def fn02(a, b):
+ return tf.math.add(a, b, name="Hole0A")
+
+
+def fn1(c, x, y, z):
+ return tf.cond(c, lambda: fn01(x, y), lambda: fn02(y, z), name="Cond0")
+
+
+def fn2(a, b):
+ return tf.math.add(a, b, name="HoleA")
+
+
+pr_ = tf.compat.v1.placeholder(tf.bool, shape=[], name="HoleC")
+op_ = tf.cond(pr_, lambda: fn1(pr_, x_, y_, z_), lambda: fn2(y_, z_), name="Cond")
+re_ = tf.identity(op_, name="HoleR")
minSdkVersion 26
targetSdkVersion 29
versionCode 1
- versionName "1.15.0"
+ versionName "1.17.0"
externalNativeBuild {
ndkBuild {
return()
endif(NOT TensorFlowLite_FOUND)
-add_subdirectory(port)
-
file(GLOB_RECURSE SOURCES "src/*.cpp")
file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
add_library(nnfw_lib_tflite STATIC ${SOURCES})
set_target_properties(nnfw_lib_tflite PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(nnfw_lib_tflite PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
-target_link_libraries(nnfw_lib_tflite PUBLIC tensorflow-lite-ex)
+target_link_libraries(nnfw_lib_tflite PUBLIC tensorflow-lite)
target_link_libraries(nnfw_lib_tflite PUBLIC nnfw_lib_misc)
target_link_libraries(nnfw_lib_tflite PRIVATE ${LIB_PTHREAD} dl)
target_link_libraries(nnfw_lib_tflite PRIVATE nnfw_common)
#define __NNFW_TFLITE_NNAPI_SESSION_H__
#include "Session.h"
-#include "tflite/ext/nnapi_delegate.h"
namespace nnfw
{
*/
NNAPISession(::tflite::Interpreter *interp) : _interp{interp}
{
- // Construct Graph from Interpreter
- // primary_subgraph: Experimental interface. Return 1st sugbraph
- _delegate.BuildGraph(&interp->primary_subgraph());
+ // DO NOTHING
}
public:
{
-    // Explicitly turn off T/F lite internal NNAPI delegation in order to use locally defined
-    // NNAPI delegation.
-    _interp->UseNNAPI(false);
+    // Use T/F lite internal NNAPI delegation since the locally defined NNAPI delegate has been removed.
+    _interp->UseNNAPI(true);
if (kTfLiteOk != _interp->AllocateTensors())
{
* @brief Run the Invoke function of NNAPI delegate
* @return @c true if Invoke() is successful, otherwise @c false
*/
- bool run(void) override { return kTfLiteOk == _delegate.Invoke(&_interp->primary_subgraph()); }
+ bool run(void) override { return kTfLiteOk == _interp->Invoke(); }
/**
* @brief Tear down TfLite interpreter session
private:
::tflite::Interpreter *const _interp;
- nnfw::tflite::NNAPIDelegate _delegate;
};
} // namespace tflite
+++ /dev/null
-if(NOT SUPPORT_TFLITE_VERSION VERSION_EQUAL 1.13.1)
- return()
-endif(NOT SUPPORT_TFLITE_VERSION VERSION_EQUAL 1.13.1)
-
-file(GLOB_RECURSE SOURCES "src/*.cpp")
-
-add_library(tensorflow-lite-ex STATIC ${SOURCES})
-set_target_properties(tensorflow-lite-ex PROPERTIES POSITION_INDEPENDENT_CODE ON)
-target_include_directories(tensorflow-lite-ex PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
-target_link_libraries(tensorflow-lite-ex PUBLIC tensorflow-lite)
-target_link_libraries(tensorflow-lite-ex PUBLIC nnfw_lib_misc nnfw_lib_rua_shim)
-target_link_libraries(tensorflow-lite-ex PRIVATE ${LIB_PTHREAD} dl)
-target_link_libraries(tensorflow-lite-ex PRIVATE nnfw_common)
-target_link_libraries(tensorflow-lite-ex PRIVATE nnfw_coverage)
+++ /dev/null
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * @file CustomOps.h
- * @brief This file contains registration of custom operands
- * @ingroup COM_AI_RUNTIME
- */
-
-#ifndef __NNFW_TFLITE_EXT_KERNELS_CUSTOM_OP_H__
-#define __NNFW_TFLITE_EXT_KERNELS_CUSTOM_OP_H__
-
-#include "tensorflow/lite/context.h"
-#include "tflite/ext/kernels/SquaredDifference.h"
-
-namespace nnfw
-{
-namespace tflite
-{
-namespace custom
-{
-
-#define REGISTER_FUNCTION(Name) \
- TfLiteRegistration *Register_##Name(void) \
- { \
- static TfLiteRegistration r = {}; \
- r.init = Name::Init##Name; \
- r.free = Name::Free##Name; \
- r.prepare = Name::Prepare##Name; \
- r.invoke = Name::Eval##Name; \
- r.custom_name = #Name; \
- return &r; \
- }
-
-REGISTER_FUNCTION(SquaredDifference)
-
-#undef REGISTER_FUNCTION
-
-} // namespace custom
-} // namespace tflite
-} // namespace nnfw
-
-#endif // __NNFW_TFLITE_EXT_KERNELS_CUSTOM_OP_H__
+++ /dev/null
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * @file SquaredDifference.h
- * @brief This file contains SquaredDifference namespace and SquaredDifference function
- * definitions
- * @ingroup COM_AI_RUNTIME
- */
-
-#ifndef __NNFW_TFLITE_EXT_KERNELS_SQUARED_DIFFERENCE_H__
-#define __NNFW_TFLITE_EXT_KERNELS_SQUARED_DIFFERENCE_H__
-
-#include "tensorflow/lite/context.h"
-
-namespace nnfw
-{
-namespace tflite
-{
-namespace custom
-{
-namespace SquaredDifference
-{
-
-/**
- * @brief Initialize SquaredDifference operand using the contents of buffer
- * @param[in] context The TfLite context
- * @param[in] buffer The buffer with contents
- * @param[in] length The buffer length
- * @return The void pointer for user data
- */
-void *InitSquaredDifference(TfLiteContext *context, const char *buffer, size_t length);
-
-/**
- * @brief Release any memory it might have allocated via 'InitSquaredDifference'
- * @param[in] context The TfLite context
- * @param[in] buffer The buffer with contents
- * @return N/A
- */
-void FreeSquaredDifference(TfLiteContext *context, void *buffer);
-
-/**
- * @brief Prepare the SquaredDifference operand for execution
- * @param[in] context The TfLite context
- * @param[in] node The operand node
- * @return The TfLite status
- */
-TfLiteStatus PrepareSquaredDifference(TfLiteContext *context, TfLiteNode *node);
-
-/**
- * @brief Evaluation the SquaredDifference operand for execution
- * @param[in] context The TfLite context
- * @param[in] node The operand node
- * @return The TfLite status
- */
-TfLiteStatus EvalSquaredDifference(TfLiteContext *context, TfLiteNode *node);
-
-} // namespace SquaredDifference
-} // namespace custom
-} // namespace tflite
-} // namespace nnfw
-
-#endif // __NNFW_TFLITE_EXT_KERNELS_SQUARED_DIFFERENCE_H__
+++ /dev/null
-/* Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// NOTE To minimize diff with upstream tensorflow, disable clang-format
-// clang-format off
-
-// NOTE This header is derived from the following file (in TensorFlow v1.13.1)
-// 'externals/tensorflow/tensorflow/lite/kernels/register.h'
-#ifndef __NNFW_TFLITE_EXT_KERNELS_REGISTER_H__
-#define __NNFW_TFLITE_EXT_KERNELS_REGISTER_H__
-
-#include <unordered_map>
-#include "tensorflow/lite/context.h"
-#include "tensorflow/lite/model.h"
-
-namespace nnfw {
-namespace tflite {
-
-class BuiltinOpResolver : public ::tflite::MutableOpResolver {
- public:
- BuiltinOpResolver();
-
- const TfLiteRegistration* FindOp(::tflite::BuiltinOperator op,
- int version) const override;
- const TfLiteRegistration* FindOp(const char* op, int version) const override;
-};
-
-} // namespace tflite
-} // namespace nnfw
-
-#endif // __NNFW_TFLITE_EXT_KERNELS_REGISTER_H__
-
-// clang-format on
+++ /dev/null
-/* Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// NOTE To minimize diff with upstream tensorflow, disable clang-format
-// clang-format off
-
-// NOTE This header is derived from the following file (in TensorFlow v1.13.1)
-// 'externals/tensorflow/tensorflow/lite/nnapi_delegate.h'
-#ifndef __NNFW_TFLITE_EXT_NNAPI_DELEGATE_H__
-#define __NNFW_TFLITE_EXT_NNAPI_DELEGATE_H__
-
-#include "tensorflow/lite/allocation.h"
-#include "tensorflow/lite/c/c_api_internal.h"
-#include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/core/subgraph.h"
-#include "tensorflow/lite/interpreter.h"
-
-struct ANeuralNetworksModel;
-struct ANeuralNetworksMemory;
-struct ANeuralNetworksCompilation;
-
-namespace nnfw {
-namespace tflite {
-
-class NNAPIAllocation : public ::tflite::MMAPAllocation {
- public:
- NNAPIAllocation(const char* filename, ::tflite::ErrorReporter* error_reporter);
- ~NNAPIAllocation();
-
- size_t offset(const void* ptr) const {
- auto signed_offset = reinterpret_cast<const uint8_t*>(ptr) -
- reinterpret_cast<const uint8_t*>(mmapped_buffer_);
-
- return static_cast<size_t>(signed_offset);
- }
-
- ANeuralNetworksMemory* memory() const { return handle_; }
- bool valid() const override { return handle_ != nullptr; }
-
- private:
- mutable ANeuralNetworksMemory* handle_ = nullptr;
-};
-
-class NNAPIDelegate {
- public:
- ~NNAPIDelegate();
-
- // Convert a tflite graph to NNAPI
- TfLiteStatus BuildGraph(::tflite::Subgraph* subgraph);
-
- // Run
- TfLiteStatus Invoke(::tflite::Subgraph* subgraph);
-
- // Whether the current platform supports NNAPI delegation.
- static bool IsSupported();
-
- private:
- // The NN API model handle
- ANeuralNetworksModel* nn_model_ = nullptr;
- // The NN API compilation handle
- ANeuralNetworksCompilation* nn_compiled_model_ = nullptr;
- // Model status
- TfLiteStatus model_status_ = kTfLiteOk;
-
- // List of state tensors for LSTM, RNN, SVDF.
- // NN API does not allow ops to maintain states across multiple
- // invocations. We need to manually create state input tensors from
- // corresponding state output tensors of TFLite operations, and map them
- // correctly.
- std::vector<int> model_states_inputs_; // holds NNAPI operand ids
- std::vector<int> model_states_outputs_; // holds TFLite tensor ids
-};
-
-} // namespace tflite
-} // namespace nnfw
-
-#endif // __NNFW_TFLITE_EXT_NNAPI_DELEGATE_H__
-
-// clang-format on
+++ /dev/null
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "tflite/ext/kernels/SquaredDifference.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-
-#include <iostream>
-
-namespace nnfw
-{
-namespace tflite
-{
-namespace custom
-{
-namespace SquaredDifference
-{
-
-void *InitSquaredDifference(TfLiteContext *, const char *, size_t) { return nullptr; }
-
-void FreeSquaredDifference(TfLiteContext *, void *) {}
-
-TfLiteStatus PrepareSquaredDifference(TfLiteContext *context, TfLiteNode *node)
-{
- TF_LITE_ENSURE_EQ(context, ::tflite::NumInputs(node), 2);
- TF_LITE_ENSURE_EQ(context, ::tflite::NumOutputs(node), 1);
-
- const TfLiteTensor *input1 = ::tflite::GetInput(context, node, 0);
- const TfLiteTensor *input2 = ::tflite::GetInput(context, node, 1);
- TfLiteTensor *output = ::tflite::GetOutput(context, node, 0);
-
- TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
- TF_LITE_ENSURE_EQ(context, input1->type, output->type);
-
- return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input1->dims));
-}
-
-TfLiteStatus EvalSquaredDifference(TfLiteContext *context, TfLiteNode *node)
-{
-
- const TfLiteTensor *input1 = ::tflite::GetInput(context, node, 0);
- const TfLiteTensor *input2 = ::tflite::GetInput(context, node, 1);
-
- TfLiteTensor *output = ::tflite::GetOutput(context, node, 0);
-
- size_t elements = ::tflite::NumElements(input1);
-
- switch (input1->type)
- {
- case kTfLiteFloat32:
- {
- const float *in1 = input1->data.f;
- const float *in2 = input2->data.f;
- const float *in_end1 = in1 + elements;
- float *out = output->data.f;
-
- for (; in1 < in_end1; in1++, in2++, out++)
- *out = ((*in1 - *in2) * (*in1 - *in2));
-
- return kTfLiteOk;
- }
- case kTfLiteInt32:
- {
- const int *in1 = input1->data.i32;
- const int *in2 = input2->data.i32;
- const int *in_end1 = in1 + elements;
- int *out = output->data.i32;
-
- for (; in1 < in_end1; in1++, in2++, out++)
- *out = ((*in1 - *in2) * (*in1 - *in2));
-
- return kTfLiteOk;
- }
- case kTfLiteInt64:
- {
- const int64_t *in1 = input1->data.i64;
- const int64_t *in2 = input1->data.i64;
- const int64_t *in_end1 = in1 + elements;
- int64_t *out = output->data.i64;
-
- for (; in1 < in_end1; in1++, in2++, out++)
- *out = ((*in1 - *in2) * (*in1 - *in2));
-
- return kTfLiteOk;
- }
- default:
- {
- context->ReportError(context, "InputType is %d Unsupported", input1->type);
- return kTfLiteError;
- }
- }
-}
-
-} // namespace SquaredDifference
-} // namespace custom
-} // namespace tflite
-} // namespace nnfw
+++ /dev/null
-/* Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// NOTE To minimize diff with upstream tensorflow, disable clang-format
-// clang-format off
-
-// NOTE This code is derived from the following file (in TensorFlow v1.13.1)
-// 'externals/tensorflow/tensorflow/lite/kernels/register.cc'
-#include "tflite/ext/kernels/register.h"
-#include "tensorflow/lite/util.h"
-#include "tflite/ext/kernels/CustomOps.h"
-
-namespace tflite {
-namespace ops {
-
-namespace custom {
-
-// Need additional external library for AUDIO_SPECTROGRAM
-//TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
-TfLiteRegistration* Register_LAYER_NORM_LSTM();
-TfLiteRegistration* Register_MFCC();
-TfLiteRegistration* Register_DETECTION_POSTPROCESS();
-TfLiteRegistration* Register_RELU_1();
-
-} // namespace custom
-}
-}
-
-namespace tflite {
-namespace ops {
-namespace builtin {
-
-TfLiteRegistration* Register_ABS();
-TfLiteRegistration* Register_RELU();
-TfLiteRegistration* Register_RELU_N1_TO_1();
-TfLiteRegistration* Register_RELU6();
-TfLiteRegistration* Register_TANH();
-TfLiteRegistration* Register_LOGISTIC();
-TfLiteRegistration* Register_AVERAGE_POOL_2D();
-TfLiteRegistration* Register_MAX_POOL_2D();
-TfLiteRegistration* Register_L2_POOL_2D();
-TfLiteRegistration* Register_CONV_2D();
-TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
-TfLiteRegistration* Register_SVDF();
-TfLiteRegistration* Register_RNN();
-TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN();
-TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN();
-TfLiteRegistration* Register_EMBEDDING_LOOKUP();
-TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE();
-TfLiteRegistration* Register_FULLY_CONNECTED();
-TfLiteRegistration* Register_LSH_PROJECTION();
-TfLiteRegistration* Register_HASHTABLE_LOOKUP();
-TfLiteRegistration* Register_SOFTMAX();
-TfLiteRegistration* Register_CONCATENATION();
-TfLiteRegistration* Register_ADD();
-TfLiteRegistration* Register_SPACE_TO_BATCH_ND();
-TfLiteRegistration* Register_DIV();
-TfLiteRegistration* Register_SUB();
-TfLiteRegistration* Register_BATCH_TO_SPACE_ND();
-TfLiteRegistration* Register_MUL();
-TfLiteRegistration* Register_L2_NORMALIZATION();
-TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION();
-TfLiteRegistration* Register_LSTM();
-TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
-TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
-TfLiteRegistration* Register_PAD();
-TfLiteRegistration* Register_PADV2();
-TfLiteRegistration* Register_RESHAPE();
-TfLiteRegistration* Register_RESIZE_BILINEAR();
-TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR();
-TfLiteRegistration* Register_SKIP_GRAM();
-TfLiteRegistration* Register_SPACE_TO_DEPTH();
-TfLiteRegistration* Register_GATHER();
-TfLiteRegistration* Register_TRANSPOSE();
-TfLiteRegistration* Register_MEAN();
-TfLiteRegistration* Register_SPLIT();
-TfLiteRegistration* Register_SPLIT_V();
-TfLiteRegistration* Register_SQUEEZE();
-TfLiteRegistration* Register_STRIDED_SLICE();
-TfLiteRegistration* Register_EXP();
-TfLiteRegistration* Register_TOPK_V2();
-TfLiteRegistration* Register_LOG();
-TfLiteRegistration* Register_LOG_SOFTMAX();
-TfLiteRegistration* Register_CAST();
-TfLiteRegistration* Register_DEQUANTIZE();
-TfLiteRegistration* Register_PRELU();
-TfLiteRegistration* Register_MAXIMUM();
-TfLiteRegistration* Register_MINIMUM();
-TfLiteRegistration* Register_ARG_MAX();
-TfLiteRegistration* Register_ARG_MIN();
-TfLiteRegistration* Register_GREATER();
-TfLiteRegistration* Register_GREATER_EQUAL();
-TfLiteRegistration* Register_LESS();
-TfLiteRegistration* Register_LESS_EQUAL();
-TfLiteRegistration* Register_FLOOR();
-TfLiteRegistration* Register_TILE();
-TfLiteRegistration* Register_NEG();
-TfLiteRegistration* Register_SUM();
-TfLiteRegistration* Register_REDUCE_PROD();
-TfLiteRegistration* Register_REDUCE_MAX();
-TfLiteRegistration* Register_REDUCE_MIN();
-TfLiteRegistration* Register_REDUCE_ANY();
-TfLiteRegistration* Register_SELECT();
-TfLiteRegistration* Register_SLICE();
-TfLiteRegistration* Register_SIN();
-TfLiteRegistration* Register_TRANSPOSE_CONV();
-TfLiteRegistration* Register_EXPAND_DIMS();
-TfLiteRegistration* Register_SPARSE_TO_DENSE();
-TfLiteRegistration* Register_EQUAL();
-TfLiteRegistration* Register_NOT_EQUAL();
-TfLiteRegistration* Register_SQRT();
-TfLiteRegistration* Register_RSQRT();
-TfLiteRegistration* Register_SHAPE();
-TfLiteRegistration* Register_POW();
-TfLiteRegistration* Register_FAKE_QUANT();
-TfLiteRegistration* Register_PACK();
-TfLiteRegistration* Register_ONE_HOT();
-TfLiteRegistration* Register_LOGICAL_OR();
-TfLiteRegistration* Register_LOGICAL_AND();
-TfLiteRegistration* Register_LOGICAL_NOT();
-TfLiteRegistration* Register_UNPACK();
-TfLiteRegistration* Register_FLOOR_DIV();
-TfLiteRegistration* Register_SQUARE();
-TfLiteRegistration* Register_ZEROS_LIKE();
-TfLiteRegistration* Register_FLOOR_MOD();
-TfLiteRegistration* Register_RANGE();
-TfLiteRegistration* Register_LEAKY_RELU();
-TfLiteRegistration* Register_SQUARED_DIFFERENCE();
-TfLiteRegistration* Register_FILL();
-TfLiteRegistration* Register_MIRROR_PAD();
-
-} // namespace builtin
-} // namespace ops
-} // namespace tflite
-
-namespace nnfw {
-namespace tflite {
-
-// Using namespace directive to minimize diff with upstream tensorflow
-using namespace ::tflite::ops::custom;
-using namespace ::tflite::ops::builtin;
-using namespace ::tflite;
-
-// Fix to use strict build option
-TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* /*node*/) {
- context->ReportError(
- context,
- "Regular TensorFlow ops are not supported by this interpreter. Make sure "
- "you invoke the Flex delegate before inference.");
- return kTfLiteError;
-}
-
-const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op,
- int version) const {
- return MutableOpResolver::FindOp(op, version);
-}
-
-const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
- int version) const {
- // Return the NULL Op for all ops whose name start with "Flex", allowing
- // the interpreter to delegate their execution.
- if (IsFlexOp(op)) {
- static TfLiteRegistration null_op{
- nullptr, nullptr, &UnsupportedTensorFlowOp,
- nullptr, nullptr, BuiltinOperator_CUSTOM,
- "Flex", 1};
- return &null_op;
- }
- return MutableOpResolver::FindOp(op, version);
-}
-
-BuiltinOpResolver::BuiltinOpResolver() {
- AddBuiltin(BuiltinOperator_ABS, Register_ABS());
- AddBuiltin(BuiltinOperator_RELU, Register_RELU());
- AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
- AddBuiltin(BuiltinOperator_RELU6, Register_RELU6());
- AddBuiltin(BuiltinOperator_TANH, Register_TANH());
- AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC());
- AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
- AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
- AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
- AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
- AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
- /* min_version */ 1,
- /* max_version */ 2);
- AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
- AddBuiltin(BuiltinOperator_RNN, Register_RNN());
- AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
- Register_BIDIRECTIONAL_SEQUENCE_RNN());
- AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
- Register_UNIDIRECTIONAL_SEQUENCE_RNN());
- AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP());
- AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
- Register_EMBEDDING_LOOKUP_SPARSE());
- AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
- /* min_version */ 1,
- /* max_version */ 2);
- AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
- AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
- AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX());
- AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION());
- AddBuiltin(BuiltinOperator_ADD, Register_ADD());
- AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND());
- AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND());
- AddBuiltin(BuiltinOperator_MUL, Register_MUL());
- AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
- AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
- Register_LOCAL_RESPONSE_NORMALIZATION());
- AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
- /* max_version */ 2);
- AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
- Register_BIDIRECTIONAL_SEQUENCE_LSTM());
- AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
- Register_UNIDIRECTIONAL_SEQUENCE_LSTM());
- AddBuiltin(BuiltinOperator_PAD, Register_PAD());
- AddBuiltin(BuiltinOperator_PADV2, Register_PADV2());
- AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE());
- AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR());
- AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
- Register_RESIZE_NEAREST_NEIGHBOR());
- AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
- AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH());
- AddBuiltin(BuiltinOperator_GATHER, Register_GATHER());
- AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE());
- AddBuiltin(BuiltinOperator_MEAN, Register_MEAN());
- AddBuiltin(BuiltinOperator_DIV, Register_DIV());
- AddBuiltin(BuiltinOperator_SUB, Register_SUB());
- AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT());
- AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
- AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
- AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
- AddBuiltin(BuiltinOperator_EXP, Register_EXP());
- AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
- AddBuiltin(BuiltinOperator_LOG, Register_LOG());
- AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
- AddBuiltin(BuiltinOperator_CAST, Register_CAST());
- AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
- /* min_version */ 1,
- /* max_version */ 2);
- AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
- AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
- AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
- AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
- AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
- AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
- AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
- AddBuiltin(BuiltinOperator_LESS, Register_LESS());
- AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL());
- AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
- AddBuiltin(BuiltinOperator_NEG, Register_NEG());
- AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
- AddBuiltin(BuiltinOperator_SLICE, Register_SLICE());
- AddBuiltin(BuiltinOperator_SIN, Register_SIN());
- AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
- AddBuiltin(BuiltinOperator_TILE, Register_TILE());
- AddBuiltin(BuiltinOperator_SUM, Register_SUM());
- AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD());
- AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX());
- AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN());
- AddBuiltin(BuiltinOperator_REDUCE_ANY, Register_REDUCE_ANY());
- AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS());
- AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
- AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
- AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL());
- AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
- AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
- AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
- AddBuiltin(BuiltinOperator_POW, Register_POW());
- AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
- AddBuiltin(BuiltinOperator_PACK, Register_PACK());
- AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
- AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
- AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
- AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
- AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
- AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
- AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
- AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
- AddBuiltin(BuiltinOperator_FLOOR_MOD, Register_FLOOR_MOD());
- AddBuiltin(BuiltinOperator_RANGE, Register_RANGE());
- AddBuiltin(BuiltinOperator_LEAKY_RELU, Register_LEAKY_RELU());
- AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE());
- AddBuiltin(BuiltinOperator_FILL, Register_FILL());
- AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
-
- AddCustom("SquaredDifference", nnfw::tflite::custom::Register_SquaredDifference());
-
- // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
- // custom ops aren't always included by default.
- AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
- // Need additional external library for audio spectrogram
- //AddCustom("AudioSpectrogram",
- // tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
- AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
- AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
- AddCustom("TFLite_Detection_PostProcess",
- tflite::ops::custom::Register_DETECTION_POSTPROCESS());
-}
-
-} // namespace tflite
-} // namespace nnfw
+++ /dev/null
-/* Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// NOTE To minimize diff with upstream tensorflow, disable clang-format
-// clang-format off
-
-// NOTE This code is derived from the following file (in TensorFlow v1.13.1)
-// 'externals/tensorflow/tensorflow/lite/nnapi_delegate.cc'
-#include "tflite/ext/nnapi_delegate.h"
-#include <fcntl.h>
-#include <sys/mman.h>
-#include <sys/stat.h>
-#include <sys/types.h>
-#include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/core/api/error_reporter.h"
-#include "tensorflow/lite/model.h"
-#include <rua/Shim.h>
-#include "NeuralNetworksExShim.h"
-
-#ifdef __ANDROID__
-#include <android/log.h>
-#include <sys/system_properties.h>
-#endif
-
-#include <memory>
-
-namespace nnfw {
-namespace tflite {
-
-void logError(const char* format, ...) {
- // stderr is convenient for native tests, but is not captured for apps
- va_list args_for_stderr;
- va_start(args_for_stderr, format);
- vfprintf(stderr, format, args_for_stderr);
- va_end(args_for_stderr);
- fprintf(stderr, "\n");
- fflush(stderr);
-#ifdef __ANDROID__
- // produce logcat output for general consumption
- va_list args_for_log;
- va_start(args_for_log, format);
- __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, args_for_log);
- va_end(args_for_log);
-#endif
-}
-
-#define FATAL(...) \
- logError(__VA_ARGS__); \
- exit(1);
-
-// TODO(aselle): Change the error model to use status codes.
-#define CHECK_TFLITE_SUCCESS(x) \
- if (x != kTfLiteOk) { \
- FATAL("Aborting since tflite returned failure nnapi_delegate.cc:%d.", \
- __LINE__); \
- }
-
-#define CHECK_NN(x) \
- if (x != ANEURALNETWORKS_NO_ERROR) { \
- FATAL("Aborting since NNAPI returned failure nnapi_delegate.cc:%d", \
- __LINE__); \
- }
-
-#define RETURN_ERROR_IF_TFLITE_FAILED(x) \
- if (x != kTfLiteOk) { \
- logError( \
- "Returning error since TFLite returned failure nnapi_delegate.cc:%d.", \
- __LINE__); \
- return kTfLiteError; \
- }
-
-#define RETURN_ERROR_IF_NN_FAILED(x) \
- if (x != ANEURALNETWORKS_NO_ERROR) { \
- logError( \
- "Returning error since NNAPI returned failure nnapi_delegate.cc:%d.", \
- __LINE__); \
- return kTfLiteError; \
- }
-
-// Tracking of NNAPI operand ids
-static const int64_t kOperandIdNotSet = -1;
-static const int64_t kOperandNotNeeded = -2;
-
-namespace {
-
-int32_t GetAndroidSdkVersion() {
-#ifdef __ANDROID__
- const char* sdkProp = "ro.build.version.sdk";
- char sdkVersion[PROP_VALUE_MAX];
- int length = __system_property_get(sdkProp, sdkVersion);
- if (length != 0) {
- for (int i = 0; i < length; ++i) {
- int digit = sdkVersion[i] - '0';
- if (digit < 0 || digit > 9) {
- // Non-numeric SDK version, assume it's higher then expected;
- return 0xFFFF;
- }
- }
- // NOTE use std::strtol instead of atoi: security issue
- return std::strtol(sdkVersion, NULL, 0);
- }
- FATAL("No %s prop", sdkProp);
-#endif // __ANDROID__
- return 0;
-}
-
-int32_t GetAndroidSdkVersionCached() {
- static int32_t androidSdkVersion = GetAndroidSdkVersion();
- return androidSdkVersion;
-}
-
-// WORKAROUND Some model have dimension zero
-// Consider scalar as vector size 1
-static const uint32_t dimension_for_scalar[1] = {1};
-
-} // namespace
-
-NNAPIAllocation::NNAPIAllocation(const char* filename,
- ::tflite::ErrorReporter* error_reporter)
- : MMAPAllocation(filename, error_reporter) {
- if (mmapped_buffer_ != MAP_FAILED)
- CHECK_NN(ANeuralNetworksMemory_createFromFd(buffer_size_bytes_, PROT_READ,
- mmap_fd_, 0, &handle_));
-}
-
-NNAPIAllocation::~NNAPIAllocation() {
- if (handle_) {
- ANeuralNetworksMemory_free(handle_);
- }
-}
-
-NNAPIDelegate::~NNAPIDelegate() {
- if (nn_compiled_model_) {
- ANeuralNetworksCompilation_free(nn_compiled_model_);
- nn_compiled_model_ = nullptr;
- }
- if (nn_model_) {
- ANeuralNetworksModel_free(nn_model_);
- nn_model_ = nullptr;
- // TODO(aselle): Is this thread-safe and callable multiple times?
- }
- // ANeuralNetworksShutdown();
-}
-
-// Adds the tensors of the subgraph to the NN API model.
-TfLiteStatus addTensorOperands(::tflite::Subgraph* subgraph,
- ANeuralNetworksModel* nn_model,
- uint32_t* no_of_operands_added,
- std::vector<int64_t>* nnapi_ids) {
- uint32_t next_id = 0;
- // Allocate temporary buffer to save casted boolean tensor
- std::unordered_map<size_t, std::unique_ptr<uint8_t[]>> const_boolean_tensors;
-
- for (size_t i = 0; i < subgraph->tensors_size(); i++) {
- // Skip temporaries and RNN back-edges.
- if ((*nnapi_ids)[i] == kOperandNotNeeded) continue;
-
- (*nnapi_ids)[i] = int64_t(next_id);
-
- int32_t nn_type = 0;
- // NNAPI requires 32-bit float scale to be zero, tflite doesn't care
- float scale = 0.0f;
- int32_t zeroPoint = 0;
- TfLiteTensor* tensor = subgraph->tensor(i);
- switch (tensor->type) {
- case kTfLiteNoType:
- // Tensors added during initialization of Ops don't have a type yet and
- // should not be registered with the NNAPI.
- continue;
- case kTfLiteFloat32:
- nn_type = ANEURALNETWORKS_TENSOR_FLOAT32;
- break;
- case kTfLiteUInt8:
- // NNAPI uses ANEURALNETWORKS_TENSOR_QUANT8_ASYMM to represent uint8 type
- // ex. ANEURALNETWORKS_CAST
- nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM;
- scale = tensor->params.scale;
- // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM type requires scale > 0,
- // zeroPoint >= 0 and zeroPoint <= 255
- scale = (scale == 0.0f) ? 1.0f : scale;
- zeroPoint = tensor->params.zero_point;
- break;
- case kTfLiteInt32:
- nn_type = ANEURALNETWORKS_TENSOR_INT32;
- scale = tensor->params.scale;
- zeroPoint = tensor->params.zero_point;
- break;
- case kTfLiteBool:
- // Workaround to pass bool type under NNAPI
- // Use bool type using ANEURALNETWORKS_TENSOR_QUANT8_ASYMM with scale = 1.0f and zero_point = 0
- nn_type = ANEURALNETWORKS_TENSOR_BOOL8;
- break;
- default:
- logError("Unsupported tensor type %d", tensor->type);
- return kTfLiteError;
- }
- if (tensor->dims->size == 0) {
- // WORKAROUND Some model have dimension zero
- switch (tensor->type) {
- case kTfLiteFloat32:
- nn_type = ANEURALNETWORKS_TENSOR_FLOAT32;
- break;
- case kTfLiteInt32:
- nn_type = ANEURALNETWORKS_TENSOR_INT32;
- break;
- default:
- logError("NNAPI doesn't support tensors with rank 0 (index %d name %s)",
- i, tensor->name);
- return kTfLiteError;
- }
- }
- if (tensor->dims->size > 4) {
- logError("NNAPI doesn't support tensors with rank > 4 (index %d name %s)",
- i, tensor->name);
- return kTfLiteError;
- }
- // TODO(aselle): Note, many of these are intermediate results. Do I need
- // to ever specify these sizes. I am currently below doing setValue
- // on all of them, but I shouldn't in the future.
- // Answer(jeanluc): If all the operators can set the dimension correctly,
- // you won't need to.
- ANeuralNetworksOperandType operand_type{
- nn_type, static_cast<uint32_t>(tensor->dims->size),
- reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
- if (tensor->dims->size == 0) {
- // WORKAROUND Some model have dimension zero
- // Consider scalar as vector size 1
- operand_type.dimensions = dimension_for_scalar;
- operand_type.dimensionCount = 1;
- }
- RETURN_ERROR_IF_NN_FAILED(
- ANeuralNetworksModel_addOperand(nn_model, &operand_type));
- // TODO(aselle): Based on Michael's suggestion, limiting this to read
- // only memory
- if (tensor->allocation_type == kTfLiteMmapRo) {
- if (tensor->type == kTfLiteBool)
- {
- // ANEURALNETWORKS_TENSOR_BOOL8 tensor element size is 8 bits
- size_t elements = tensor->bytes / sizeof(bool);
- const_boolean_tensors[i] = std::make_unique<uint8_t[]>(elements);
- for (size_t idx = 0; idx < elements; idx++)
- {
- const_boolean_tensors[i].get()[idx] = (tensor->data.b[idx] ? 0x00 : 0xff);
- }
- RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_setOperandValue(
- nn_model, next_id, const_boolean_tensors[i].get(), tensor->bytes));
- }
- else if (const NNAPIAllocation* alloc = dynamic_cast<const NNAPIAllocation*>(
- static_cast<const ::tflite::Allocation*>(tensor->allocation))) {
- RETURN_ERROR_IF_NN_FAILED(
- ANeuralNetworksModel_setOperandValueFromMemory(
- nn_model, next_id, alloc->memory(),
- alloc->offset(tensor->data.raw), tensor->bytes));
- } else {
- RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_setOperandValue(
- nn_model, next_id, tensor->data.raw, tensor->bytes));
- }
- } else if (tensor->bytes == 0) {
- // These size 0 tensors are optional tensors reserved.
- RETURN_ERROR_IF_NN_FAILED(
- ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0));
- }
-
- ++next_id;
- }
- *no_of_operands_added = next_id;
- return kTfLiteOk;
-}
-
-void MapAndAddTensorIds(const int* from_ids_buf, size_t from_ids_count,
- std::vector<uint32_t>* into,
- const std::vector<int64_t>& map) {
- for (size_t i = 0; i < from_ids_count; i++) {
- int from_id = from_ids_buf[i];
- if (from_id == kOptionalTensor) {
- into->push_back(from_id);
- } else {
- into->push_back(map[from_id]);
- }
- }
-}
-
-// Adds the operations and their parameters to the NN API model.
-// 'next-id' is the operand ID of the next operand of the model.
-TfLiteStatus AddOpsAndParams(
- ::tflite::Subgraph* subgraph, ANeuralNetworksModel* nn_model,
- uint32_t next_id, std::vector<int>* model_state_inputs,
- std::vector<int>* model_state_outputs,
- const std::vector<int64_t>& tensor_id_to_nnapi_id) {
- for (size_t i = 0; i < subgraph->nodes_size(); i++) {
- const auto* node_and_registration = subgraph->node_and_registration(i);
- const TfLiteNode& node = node_and_registration->first;
- const TfLiteRegistration& registration = node_and_registration->second;
- ::tflite::BuiltinOperator builtin =
- static_cast<::tflite::BuiltinOperator>(registration.builtin_code);
-
- // Add the parameters.
- std::vector<uint32_t> augmented_inputs, augmented_outputs;
- MapAndAddTensorIds(node.inputs->data, node.inputs->size, &augmented_inputs,
- tensor_id_to_nnapi_id);
- MapAndAddTensorIds(node.outputs->data, node.outputs->size,
- &augmented_outputs, tensor_id_to_nnapi_id);
-
- auto add_scalar_int32 = [&nn_model, &augmented_inputs,
- &next_id](int value) {
- // Fix to use strict build option
- ANeuralNetworksOperandType operand_type{}; operand_type.type = ANEURALNETWORKS_INT32;
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
- CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value,
- sizeof(int32_t)))
- augmented_inputs.push_back(next_id++);
- };
-
- auto add_scalar_float32 = [&nn_model, &augmented_inputs,
- &next_id](float value) {
- // Fix to use strict build option
- ANeuralNetworksOperandType operand_type{}; operand_type.type = ANEURALNETWORKS_FLOAT32;
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
- CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value,
- sizeof(float)))
- augmented_inputs.push_back(next_id++);
- };
-
- auto add_vector_int32 = [&](const int* values, uint32_t num_values) {
- // Fix to use strict build option
- ANeuralNetworksOperandType operand_type{};
- operand_type.type = ANEURALNETWORKS_TENSOR_INT32;
- operand_type.dimensionCount = 1;
- operand_type.dimensions = &num_values;
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
- CHECK_NN(ANeuralNetworksModel_setOperandValue(
- nn_model, next_id, values, sizeof(int32_t) * num_values));
- augmented_inputs.push_back(next_id++);
- };
-
- // Handle state tensors of RNN, LSTM, SVDF.
- // For each state_out tensor, a corresponding state_in operand needs to be
- // created for NNAPI.
- auto duplicate_state_tensor_float32 =
- [subgraph, &nn_model, &next_id, &augmented_inputs, &model_state_inputs,
- &model_state_outputs](int tensor_id) {
- const TfLiteTensor* tensor = subgraph->tensor(tensor_id);
- ANeuralNetworksOperandType operand_type{
- ANEURALNETWORKS_TENSOR_FLOAT32,
- static_cast<uint32_t>(tensor->dims->size),
- reinterpret_cast<uint32_t*>(tensor->dims->data),
- tensor->params.scale, tensor->params.zero_point};
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
- augmented_inputs.push_back(next_id);
- model_state_inputs->push_back(next_id);
- model_state_outputs->push_back(tensor_id);
- next_id++;
- };
- auto check_and_add_activation = [&add_scalar_int32](int activation) {
- if (activation > kTfLiteActRelu6) {
- logError("NNAPI only supports RELU, RELU1 and RELU6 activations");
- return kTfLiteError;
- }
- add_scalar_int32(activation);
- return kTfLiteOk;
- };
-
- auto add_add_params = [&add_scalar_int32](void* data) {
- auto* builtin = reinterpret_cast<TfLiteAddParams*>(data);
- if (builtin->activation > kTfLiteActRelu6) {
- logError("NNAPI only supports RELU, RELU1 and RELU6 activations");
- return kTfLiteError;
- }
- add_scalar_int32(builtin->activation);
- return kTfLiteOk;
- };
-
- auto add_pooling_params = [&add_scalar_int32,
- &check_and_add_activation](void* data) {
- auto builtin = reinterpret_cast<TfLitePoolParams*>(data);
- add_scalar_int32(builtin->padding);
- add_scalar_int32(builtin->stride_width);
- add_scalar_int32(builtin->stride_height);
- add_scalar_int32(builtin->filter_width);
- add_scalar_int32(builtin->filter_height);
- return check_and_add_activation(builtin->activation);
- };
-
- auto add_convolution_params = [&add_scalar_int32,
- &check_and_add_activation](void* data) {
- auto builtin = reinterpret_cast<TfLiteConvParams*>(data);
- add_scalar_int32(builtin->padding);
- add_scalar_int32(builtin->stride_width);
- add_scalar_int32(builtin->stride_height);
- return check_and_add_activation(builtin->activation);
- };
-
- auto add_depthwise_conv_params = [&add_scalar_int32,
- &check_and_add_activation](void* data) {
- auto builtin = reinterpret_cast<TfLiteDepthwiseConvParams*>(data);
- add_scalar_int32(builtin->padding);
- add_scalar_int32(builtin->stride_width);
- add_scalar_int32(builtin->stride_height);
- add_scalar_int32(builtin->depth_multiplier);
- return check_and_add_activation(builtin->activation);
- };
-
- auto add_fully_connected_params = [&check_and_add_activation](void* data) {
- auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(data);
- return check_and_add_activation(builtin->activation);
- };
-
- auto add_concatenation_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteConcatenationParams*>(data);
- add_scalar_int32(builtin->axis);
- if (builtin->activation != kTfLiteActNone) {
- logError("Concatenation does not support fused activation in NNAPI");
- return kTfLiteError;
- }
- return kTfLiteOk;
- };
-
- auto add_softmax_params = [&add_scalar_float32](void* data) {
- auto builtin = reinterpret_cast<TfLiteSoftmaxParams*>(data);
- add_scalar_float32(builtin->beta);
- };
-
- auto add_space_to_depth_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteSpaceToDepthParams*>(data);
- add_scalar_int32(builtin->block_size);
- };
-
- auto add_lstm_params = [&add_scalar_int32,
- &add_scalar_float32](void* data) {
- auto builtin = reinterpret_cast<TfLiteLSTMParams*>(data);
- add_scalar_int32(builtin->activation);
- add_scalar_float32(builtin->cell_clip);
- add_scalar_float32(builtin->proj_clip);
- };
-
- // LSTM in NNAPI requires scratch tensor as an output operand.
- auto add_lstm_scratch_tensor_float32 = [subgraph, &node, &nn_model,
- &next_id, &augmented_outputs]() {
- if (node.temporaries->size == 0) return;
- int scratch_buffer_index = node.temporaries->data[0];
- const TfLiteTensor* tensor = subgraph->tensor(scratch_buffer_index);
- ANeuralNetworksOperandType operand_type{
- ANEURALNETWORKS_TENSOR_FLOAT32,
- static_cast<uint32_t>(tensor->dims->size),
- reinterpret_cast<uint32_t*>(tensor->dims->data), tensor->params.scale,
- tensor->params.zero_point};
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
- augmented_outputs.insert(augmented_outputs.begin(), next_id++);
- };
-
- auto add_mean_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteReducerParams*>(data);
- add_scalar_int32(builtin->keep_dims);
- };
-
- auto add_svdf_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteSVDFParams*>(data);
- add_scalar_int32(builtin->rank);
- add_scalar_int32(builtin->activation);
- };
-
- auto add_rnn_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteRNNParams*>(data);
- add_scalar_int32(builtin->activation);
- };
-
- auto add_squeeze_params = [&](void* data) {
- const auto* builtin = reinterpret_cast<TfLiteSqueezeParams*>(data);
- // Note that we add the squeeze dimensions even if the dimensions were
- // unspecified (empty), as NNAPI requires the operand.
- add_vector_int32(builtin->squeeze_dims,
- static_cast<uint32_t>(builtin->num_squeeze_dims));
- };
-
- // Handle optional input tensors.
- auto add_optional_tensors = [&nn_model, &augmented_inputs,
- &next_id](int nn_type) {
- for (size_t idx = 0; idx < augmented_inputs.size(); idx++) {
- // Fix to use strict build option
- if (augmented_inputs[idx] == static_cast<uint32_t>(kOptionalTensor)) {
- const std::vector<uint32_t> dim = {0, 0};
- ANeuralNetworksOperandType operand_type{nn_type, 2, dim.data(), 0, 0};
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
- CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id,
- nullptr, 0))
- augmented_inputs[idx] = next_id++;
- }
- }
- };
-
- int nnapi_version = 10;
-#include "nnapi_delegate_ex_AddOpsAndParams_lambda.inc"
-
- // Fix to use strict build option
- ANeuralNetworksOperationType nn_op_type = -1;
-
- // Using namespace directive to minimize diff with upstream tensorflow
- namespace tflite = ::tflite;
-
- switch (builtin) {
- case tflite::BuiltinOperator_ADD:
- nn_op_type = ANEURALNETWORKS_ADD;
- RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data));
- break;
- case tflite::BuiltinOperator_MUL:
- nn_op_type = ANEURALNETWORKS_MUL;
- RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data));
- break;
- case tflite::BuiltinOperator_AVERAGE_POOL_2D:
- RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
- nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D;
- break;
- case tflite::BuiltinOperator_MAX_POOL_2D:
- RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
- nn_op_type = ANEURALNETWORKS_MAX_POOL_2D;
- break;
- case tflite::BuiltinOperator_L2_POOL_2D:
- RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
- nn_op_type = ANEURALNETWORKS_L2_POOL_2D;
- break;
- case tflite::BuiltinOperator_CONV_2D: {
- auto builtin = reinterpret_cast<TfLiteConvParams*>(node.builtin_data);
- if (builtin->dilation_width_factor != 1 ||
- builtin->dilation_height_factor != 1 || node.inputs->size != 3) {
- logError("NNAPI does not support dilated Conv2D.");
- return kTfLiteError;
- }
- }
- RETURN_ERROR_IF_TFLITE_FAILED(
- add_convolution_params(node.builtin_data));
- nn_op_type = ANEURALNETWORKS_CONV_2D;
- break;
- case tflite::BuiltinOperator_RELU:
- nn_op_type = ANEURALNETWORKS_RELU;
- break;
- case tflite::BuiltinOperator_RELU_N1_TO_1:
- nn_op_type = ANEURALNETWORKS_RELU1;
- break;
- case tflite::BuiltinOperator_RELU6:
- nn_op_type = ANEURALNETWORKS_RELU6;
- break;
- case tflite::BuiltinOperator_TANH:
- nn_op_type = ANEURALNETWORKS_TANH;
- break;
- case tflite::BuiltinOperator_FLOOR:
- nn_op_type = ANEURALNETWORKS_FLOOR;
- break;
- case tflite::BuiltinOperator_LOGISTIC:
- nn_op_type = ANEURALNETWORKS_LOGISTIC;
- break;
- case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
- RETURN_ERROR_IF_TFLITE_FAILED(
- add_depthwise_conv_params(node.builtin_data));
- nn_op_type = ANEURALNETWORKS_DEPTHWISE_CONV_2D;
- break;
- case tflite::BuiltinOperator_CONCATENATION:
- RETURN_ERROR_IF_TFLITE_FAILED(
- add_concatenation_params(node.builtin_data));
- nn_op_type = ANEURALNETWORKS_CONCATENATION;
- break;
- case tflite::BuiltinOperator_SOFTMAX:
- add_softmax_params(node.builtin_data);
- nn_op_type = ANEURALNETWORKS_SOFTMAX;
- break;
- case tflite::BuiltinOperator_FULLY_CONNECTED:
- RETURN_ERROR_IF_TFLITE_FAILED(
- add_fully_connected_params(node.builtin_data));
- nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
- break;
- case tflite::BuiltinOperator_RESHAPE:
- if (node.inputs->size != 2) {
- logError("NNAPI only supports 2-input RESHAPE");
- return kTfLiteError;
- }
- nn_op_type = ANEURALNETWORKS_RESHAPE;
- // add_reshape_params(node.builtin_data);
- break;
- case tflite::BuiltinOperator_RESIZE_BILINEAR:
- add_resize_bilinear_params(node.builtin_data);
- nn_op_type = ANEURALNETWORKS_RESIZE_BILINEAR;
- break;
- case tflite::BuiltinOperator_SPACE_TO_DEPTH:
- add_space_to_depth_params(node.builtin_data);
- nn_op_type = ANEURALNETWORKS_SPACE_TO_DEPTH;
- break;
- case tflite::BuiltinOperator_LSTM: {
- if (node.inputs->size + /* no of params */ 3 != 21) {
- logError("NNAPI only supports 21-input LSTMs");
- return kTfLiteError;
- }
- duplicate_state_tensor_float32(
- node.outputs->data[/*kOutputStateTensor*/ 0]);
- duplicate_state_tensor_float32(
- node.outputs->data[/*kCellStateTensor*/ 1]);
- add_lstm_params(node.builtin_data);
- add_lstm_scratch_tensor_float32();
- add_optional_tensors(ANEURALNETWORKS_TENSOR_FLOAT32);
- nn_op_type = ANEURALNETWORKS_LSTM;
- break;
- }
- case tflite::BuiltinOperator_DEQUANTIZE:
- nn_op_type = ANEURALNETWORKS_DEQUANTIZE;
- break;
- case tflite::BuiltinOperator_SVDF: {
- duplicate_state_tensor_float32(node.outputs->data[/*kStateTensor*/ 0]);
- add_svdf_params(node.builtin_data);
- nn_op_type = ANEURALNETWORKS_SVDF;
- break;
- }
- case tflite::BuiltinOperator_RNN: {
- duplicate_state_tensor_float32(
- node.outputs->data[/*kHiddenStateTensor*/ 0]);
- add_rnn_params(node.builtin_data);
- nn_op_type = ANEURALNETWORKS_RNN;
- break;
- }
- case tflite::BuiltinOperator_EMBEDDING_LOOKUP:
- nn_op_type = ANEURALNETWORKS_EMBEDDING_LOOKUP;
- break;
- case tflite::BuiltinOperator_PAD:
- nnapi_version = 11; // require NNAPI 1.1
- nn_op_type = ANEURALNETWORKS_PAD;
- break;
- case tflite::BuiltinOperator_MEAN:
- nnapi_version = 11; // require NNAPI 1.1
- add_mean_params(node.builtin_data);
- nn_op_type = ANEURALNETWORKS_MEAN;
- break;
- case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION:
- nn_op_type = ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION;
- add_lrn_params(node.builtin_data);
- break;
- case tflite::BuiltinOperator_DIV:
- nnapi_version = 11; // require NNAPI 1.1
- nn_op_type = ANEURALNETWORKS_DIV;
- RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation(
- reinterpret_cast<TfLiteDivParams*>(node.builtin_data)->activation));
- break;
- case tflite::BuiltinOperator_SUB:
- nnapi_version = 11; // require NNAPI 1.1
- nn_op_type = ANEURALNETWORKS_SUB;
- RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation(
- reinterpret_cast<TfLiteSubParams*>(node.builtin_data)->activation));
- break;
- case tflite::BuiltinOperator_SQUEEZE:
- nnapi_version = 11; // requires NNAPI 1.1
- add_squeeze_params(node.builtin_data);
- nn_op_type = ANEURALNETWORKS_SQUEEZE;
- break;
- case tflite::BuiltinOperator_TRANSPOSE:
- // The permutation input tensor value dictates the output dimensions.
- // TODO(b/110888333): Support dynamically-sized tensors in delegates.
- if ((node.inputs->size > 1) &&
- (subgraph->tensor(node.inputs->data[1])->allocation_type !=
- kTfLiteMmapRo)) {
- logError("NNAPI does not yet support dynamic tensors.");
- return kTfLiteError;
- }
- nnapi_version = 11; // require NNAPI 1.1
- nn_op_type = ANEURALNETWORKS_TRANSPOSE;
- break;
- case tflite::BuiltinOperator_L2_NORMALIZATION:
- nn_op_type = ANEURALNETWORKS_L2_NORMALIZATION;
- if (reinterpret_cast<TfLiteL2NormParams*>(node.builtin_data)
- ->activation != kTfLiteActNone) {
- logError(
- "NNAPI does not support L2Normalization with fused activations");
- return kTfLiteError;
- }
- if ((node.inputs->size > 0) &&
- (subgraph->tensor(node.inputs->data[0])->dims->size != 4)) {
- logError("NNAPI only supports input rank 4 for L2Normalization");
- return kTfLiteError;
- }
- break;
- case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
- if (subgraph->tensor(node.outputs->data[0])->type != kTfLiteFloat32) {
- logError("NNAPI only support HASHTABLE_LOOKUP with float32 output",
- builtin);
- return kTfLiteError;
- }
- nn_op_type = ANEURALNETWORKS_HASHTABLE_LOOKUP;
- break;
- case tflite::BuiltinOperator_SLICE:
- nn_op_type = ANEURALNETWORKS_SLICE;
- break;
- case tflite::BuiltinOperator_STRIDED_SLICE:
- add_strided_slice_params(node.builtin_data);
- nn_op_type = ANEURALNETWORKS_STRIDED_SLICE;
- break;
- case tflite::BuiltinOperator_SPACE_TO_BATCH_ND:
- nnapi_version = 11; // require NNAPI 1.1
- nn_op_type = ANEURALNETWORKS_SPACE_TO_BATCH_ND;
- break;
- case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
- nnapi_version = 11; // require NNAPI 1.1
- nn_op_type = ANEURALNETWORKS_BATCH_TO_SPACE_ND;
- check_batch_to_space_params();
- break;
- case tflite::BuiltinOperator_CAST:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_CAST;
- break;
- case tflite::BuiltinOperator_TOPK_V2:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_TOPK_V2;
- break;
- case tflite::BuiltinOperator_GREATER:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_GREATER;
- break;
- case tflite::BuiltinOperator_GREATER_EQUAL:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_GREATER_EQUAL;
- break;
- case tflite::BuiltinOperator_LESS:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_LESS;
- break;
- case tflite::BuiltinOperator_LESS_EQUAL:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_LESS_EQUAL;
- break;
- case tflite::BuiltinOperator_GATHER:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_GATHER;
- add_gather_params(node.builtin_data);
- break;
- case tflite::BuiltinOperator_SPLIT:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_SPLIT;
- add_split_params(node.builtin_data);
- break;
- case tflite::BuiltinOperator_NEG:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_NEG;
- break;
- case tflite::BuiltinOperator_EXP:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_EXP;
- break;
- case tflite::BuiltinOperator_TRANSPOSE_CONV:
- add_transpose_conv_params(node.builtin_data);
- CHECK_NN(ANeuralNetworksModel_addOperationEx(
- nn_model, ANEURALNETWORKS_TRANSPOSE_CONV_EX,
- static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
- continue;
- case tflite::BuiltinOperator_PRELU:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_PRELU;
- break;
- case tflite::BuiltinOperator_ARG_MAX:
- check_arg_max_input(node.builtin_data);
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_ARGMAX;
- break;
- case tflite::BuiltinOperator_PACK:
- add_pack_ex_params(node.builtin_data);
- CHECK_NN(ANeuralNetworksModel_addOperationEx(
- nn_model, ANEURALNETWORKS_PACK_EX,
- static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
- continue;
- case tflite::BuiltinOperator_UNPACK:
- add_unpack_ex_params(node.builtin_data);
- CHECK_NN(ANeuralNetworksModel_addOperationEx(
- nn_model, ANEURALNETWORKS_UNPACK_EX,
- static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
- continue;
- case tflite::BuiltinOperator_SQRT:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_SQRT;
- break;
- case tflite::BuiltinOperator_RSQRT:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_RSQRT;
- break;
- case tflite::BuiltinOperator_EQUAL:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_EQUAL;
- break;
- case tflite::BuiltinOperator_NOT_EQUAL:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_NOT_EQUAL;
- break;
- case tflite::BuiltinOperator_SUM:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_REDUCE_SUM;
- add_reducer_params(node.builtin_data);
- break;
- case tflite::BuiltinOperator_REDUCE_ANY:
- add_reducer_params(node.builtin_data);
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_REDUCE_ANY;
- break;
- case tflite::BuiltinOperator_REDUCE_MAX:
- add_reducer_params(node.builtin_data);
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_REDUCE_MAX;
- break;
- case tflite::BuiltinOperator_REDUCE_MIN:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_REDUCE_MIN;
- add_reducer_params(node.builtin_data);
- break;
- case tflite::BuiltinOperator_LOG:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_LOG;
- break;
- case tflite::BuiltinOperator_LOGICAL_AND:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_LOGICAL_AND;
- break;
- case tflite::BuiltinOperator_LOGICAL_OR:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_LOGICAL_OR;
- break;
- case tflite::BuiltinOperator_LOGICAL_NOT:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_LOGICAL_NOT;
- break;
- case tflite::BuiltinOperator_SQUARED_DIFFERENCE:
- CHECK_NN(ANeuralNetworksModel_addOperationEx(
- nn_model, ANEURALNETWORKS_SQUARED_DIFFERENCE_EX,
- static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(),
- static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
- continue;
- case tflite::BuiltinOperator_MAXIMUM:
- nn_op_type = ANEURALNETWORKS_MAXIMUM;
- break;
- case tflite::BuiltinOperator_MINIMUM:
- nn_op_type = ANEURALNETWORKS_MINIMUM;
- break;
- case tflite::BuiltinOperator_ABS:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_ABS;
- break;
- case tflite::BuiltinOperator_ONE_HOT:
- add_one_hot_params(node.builtin_data);
- CHECK_NN(ANeuralNetworksModel_addOperationEx(
- nn_model, ANEURALNETWORKS_ONE_HOT_EX,
- static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
- continue; // _EX operator should use `continue` to skip addOperanation.
- case tflite::BuiltinOperator_SIN:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_SIN;
- break;
- case tflite::BuiltinOperator_SHAPE:
- CHECK_NN(ANeuralNetworksModel_addOperationEx(
- nn_model, ANEURALNETWORKS_SHAPE_EX,
- static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
- continue; // _EX operator should use `continue` to skip addOperanation.
- case tflite::BuiltinOperator_REDUCE_PROD:
- add_reducer_params(node.builtin_data);
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_REDUCE_PROD;
- break;
- case tflite::BuiltinOperator_EXPAND_DIMS:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_EXPAND_DIMS;
- break;
- case tflite::BuiltinOperator_POW:
- if (!(subgraph->tensor(node.inputs->data[0])->type == kTfLiteFloat32 &&
- subgraph->tensor(node.inputs->data[1])->type == kTfLiteFloat32)) {
- logError("NNAPI delegate for Pow supports only float32.", builtin);
- return kTfLiteError;
- }
- nn_op_type = ANEURALNETWORKS_POW;
- break;
- case tflite::BuiltinOperator_SELECT:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_SELECT;
- break;
- case tflite::BuiltinOperator_ZEROS_LIKE:
- CHECK_NN(ANeuralNetworksModel_addOperationEx(
- nn_model, ANEURALNETWORKS_ZEROS_LIKE_EX,
- static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
- continue; // _EX operator should use `continue` to skip addOperanation.
- case tflite::BuiltinOperator_TILE:
- nnapi_version = 12; // require NNAPI 1.2
- nn_op_type = ANEURALNETWORKS_TILE;
- break;
- case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
- case tflite::BuiltinOperator_LSH_PROJECTION:
- case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
- case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
- case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE:
- case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
- case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
- //case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION:
- case tflite::BuiltinOperator_PADV2:
- //case tflite::BuiltinOperator_RESIZE_BILINEAR:
- case tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
- case tflite::BuiltinOperator_CALL:
- case tflite::BuiltinOperator_SKIP_GRAM:
- //case tflite::BuiltinOperator_RELU_N1_TO_1:
- //case tflite::BuiltinOperator_GATHER:
- //case tflite::BuiltinOperator_SPACE_TO_BATCH_ND:
- //case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
- //case tflite::BuiltinOperator_TOPK_V2:
- //case tflite::BuiltinOperator_SPLIT:
- //case tflite::BuiltinOperator_STRIDED_SLICE:
- //case tflite::BuiltinOperator_EXP:
- case tflite::BuiltinOperator_LOG_SOFTMAX:
- //case tflite::BuiltinOperator_DEQUANTIZE:
- case tflite::BuiltinOperator_DELEGATE:
- //case tflite::BuiltinOperator_CAST:
- //case tflite::BuiltinOperator_PRELU:
- //case tflite::BuiltinOperator_MAXIMUM:
- //case tflite::BuiltinOperator_MINIMUM:
- //case tflite::BuiltinOperator_ARG_MAX:
- case tflite::BuiltinOperator_ARG_MIN:
- //case tflite::BuiltinOperator_GREATER:
- //case tflite::BuiltinOperator_GREATER_EQUAL:
- //case tflite::BuiltinOperator_LESS:
- //case tflite::BuiltinOperator_LESS_EQUAL:
- //case tflite::BuiltinOperator_NEG:
- //case tflite::BuiltinOperator_SELECT:
- // case tflite::BuiltinOperator_SLICE:
- //case tflite::BuiltinOperator_SIN:
- //case tflite::BuiltinOperator_LOG:
- //case tflite::BuiltinOperator_TRANSPOSE_CONV:
- //case tflite::BuiltinOperator_TILE:
- //case tflite::BuiltinOperator_EXPAND_DIMS:
- case tflite::BuiltinOperator_SPARSE_TO_DENSE:
- //case tflite::BuiltinOperator_EQUAL:
- //case tflite::BuiltinOperator_NOT_EQUAL:
- //case tflite::BuiltinOperator_SUM:
- //case tflite::BuiltinOperator_REDUCE_MAX:
- //case tflite::BuiltinOperator_REDUCE_MIN:
- //case tflite::BuiltinOperator_REDUCE_PROD:
- //case tflite::BuiltinOperator_SQRT:
- //case tflite::BuiltinOperator_RSQRT:
- //case tflite::BuiltinOperator_SHAPE:
- //case tflite::BuiltinOperator_POW:
- case tflite::BuiltinOperator_FAKE_QUANT:
- //case tflite::BuiltinOperator_PACK:
- //case tflite::BuiltinOperator_LOGICAL_OR:
- //case tflite::BuiltinOperator_ONE_HOT:
- //case tflite::BuiltinOperator_LOGICAL_AND:
- //case tflite::BuiltinOperator_LOGICAL_NOT:
- //case tflite::BuiltinOperator_UNPACK:
- case tflite::BuiltinOperator_FLOOR_DIV:
- //case tflite::BuiltinOperator_REDUCE_ANY:
- case tflite::BuiltinOperator_SQUARE:
- //case tflite::BuiltinOperator_ZEROS_LIKE:
- case tflite::BuiltinOperator_FILL:
- case tflite::BuiltinOperator_FLOOR_MOD:
- case tflite::BuiltinOperator_RANGE:
- case tflite::BuiltinOperator_LEAKY_RELU:
- //case tflite::BuiltinOperator_SQUARED_DIFFERENCE:
- case tflite::BuiltinOperator_MIRROR_PAD:
- //case tflite::BuiltinOperator_ABS:
- case tflite::BuiltinOperator_SPLIT_V:
- logError("Op code %d is currently not delegated to NNAPI", builtin);
- return kTfLiteError;
- break;
- case tflite::BuiltinOperator_CUSTOM: {
- std::string custom_name(registration.custom_name);
- if (custom_name.compare("SquaredDifference") == 0) {
- CHECK_NN(ANeuralNetworksModel_addOperationEx(
- nn_model, ANEURALNETWORKS_SQUARED_DIFFERENCE_EX,
- static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(),
- static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
- continue;
- }
- else if (custom_name.compare("MatrixBandPart") == 0) {
- CHECK_NN(ANeuralNetworksModel_addOperationEx(
- nn_model, ANEURALNETWORKS_MATRIX_BAND_PART_EX,
- static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(),
- static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
- continue;
- }
- logError("Custom operations are not supported when using NNAPI.");
- return kTfLiteError;
- break;
- }
- default:
- // Fix to use strict build option
- logError("Op code %d is currently not delegated to NNAPI", builtin);
- return kTfLiteError;
- break;
- }
-
- if (nnapi_version == 11 && GetAndroidSdkVersionCached() < 28) {
- //logError("Op %d needs NNAPI1.1", builtin);
- //return kTfLiteError;
- }
-
- // Add the operation.
- RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_addOperation(
- nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(),
- static_cast<uint32_t>(augmented_outputs.size()),
- reinterpret_cast<uint32_t*>(augmented_outputs.data())));
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus NNAPIDelegate::BuildGraph(::tflite::Subgraph* subgraph) {
- if (nn_model_ && nn_compiled_model_) return model_status_;
-
- // TODO(aselle): This is not correct. need to handle resize invalidation.
- if (!nn_model_) {
- CHECK_NN(ANeuralNetworksModel_create(&nn_model_));
-
- // Find which tensors should be added to NNAPI. TFLite has temporaries
- // and RNN back-edges which are are not valid for NNAPI. We look through all
- // inputs and outputs and mark the mapping in tensor_id_to_nnapi_id with
- // kOperandIdNotSet. addTensorOperands will replace those with the
- // corresponding NNAPI operand ids and skip kOperandNotNeeded entries.
- std::vector<int64_t> tensor_id_to_nnapi_id(subgraph->tensors_size(),
- kOperandNotNeeded);
- // Fix to use strict build option
- auto set_ids_to_not_set = [&tensor_id_to_nnapi_id](const int* buf,
- int count) {
- for (int j = 0; j < count; j++) {
- auto tensor_id = buf[j];
- if (tensor_id != kOptionalTensor) {
- tensor_id_to_nnapi_id[tensor_id] = kOperandIdNotSet;
- }
- }
- };
- for (size_t i = 0; i < subgraph->nodes_size(); i++) {
- const auto* node_and_registration = subgraph->node_and_registration(i);
- const TfLiteNode& node = node_and_registration->first;
- set_ids_to_not_set(node.inputs->data, node.inputs->size);
- set_ids_to_not_set(node.outputs->data, node.outputs->size);
- }
- set_ids_to_not_set(subgraph->inputs().data(), subgraph->inputs().size());
- set_ids_to_not_set(subgraph->outputs().data(), subgraph->outputs().size());
-
- uint32_t next_id = 0;
- RETURN_ERROR_IF_TFLITE_FAILED(addTensorOperands(
- subgraph, nn_model_, &next_id, &tensor_id_to_nnapi_id));
- RETURN_ERROR_IF_TFLITE_FAILED(
- AddOpsAndParams(subgraph, nn_model_, next_id, &model_states_inputs_,
- &model_states_outputs_, tensor_id_to_nnapi_id));
-
- std::vector<uint32_t> augmented_inputs;
- MapAndAddTensorIds(subgraph->inputs().data(), subgraph->inputs().size(),
- &augmented_inputs, tensor_id_to_nnapi_id);
- augmented_inputs.insert(augmented_inputs.end(),
- model_states_inputs_.begin(),
- model_states_inputs_.end());
- std::vector<uint32_t> augmented_outputs;
- MapAndAddTensorIds(subgraph->outputs().data(), subgraph->outputs().size(),
- &augmented_outputs, tensor_id_to_nnapi_id);
- MapAndAddTensorIds(model_states_outputs_.data(),
- model_states_outputs_.size(), &augmented_outputs,
- tensor_id_to_nnapi_id);
-
- CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs(
- nn_model_, static_cast<uint32_t>(augmented_inputs.size()),
- reinterpret_cast<const uint32_t*>(augmented_inputs.data()),
- static_cast<uint32_t>(augmented_outputs.size()),
- reinterpret_cast<const uint32_t*>(augmented_outputs.data())));
-
- // TODO Support ANeuralNetworksModel_relaxComputationFloat32toFloat16
- /*if (GetAndroidSdkVersionCached() >= 28) {
- CHECK_NN(ANeuralNetworksModel_relaxComputationFloat32toFloat16(
- nn_model_, subgraph->GetAllowFp16PrecisionForFp32()));
- }*/
- CHECK_NN(ANeuralNetworksModel_finish(nn_model_));
- }
- if (!nn_compiled_model_) {
- CHECK_NN(ANeuralNetworksCompilation_create(nn_model_, &nn_compiled_model_));
- CHECK_NN(ANeuralNetworksCompilation_finish(nn_compiled_model_));
- }
- return kTfLiteOk;
-}
-
-// Use unordered_map for temporary buffer
-#include <unordered_map>
-
-TfLiteStatus NNAPIDelegate::Invoke(::tflite::Subgraph* subgraph) {
- if (!nn_model_) {
- model_status_ = BuildGraph(subgraph);
- if (model_status_ != kTfLiteOk) {
- logError("Failed to build graph for NNAPI");
- }
- }
- if (model_status_ != kTfLiteOk) {
- return model_status_;
- }
-
- ANeuralNetworksExecution* execution = nullptr;
- CHECK_NN(ANeuralNetworksExecution_create(nn_compiled_model_, &execution));
-
- // Allocate temporary buffer to save casted boolean tensor
- std::unordered_map<size_t, uint8_t*> input_boolean_tensors;
- std::unordered_map<size_t, uint8_t*> output_boolean_tensors;
- for (size_t i = 0; i < subgraph->inputs().size(); i++)
- {
- int input = subgraph->inputs()[i];
- TfLiteTensor* tensor = subgraph->tensor(input);
- if (tensor->type == kTfLiteBool)
- {
- size_t elements = tensor->bytes / sizeof(bool);
- uint8_t* temp_tensor = new uint8_t[tensor->bytes / sizeof(bool)];
- input_boolean_tensors[i] = temp_tensor;
- for (size_t idx = 0; idx < elements; idx++)
- {
- temp_tensor[idx] = (tensor->data.b[idx] ? 0x00 : 0xff);
- }
- }
- }
- for (size_t i = 0; i < subgraph->outputs().size(); i++)
- {
- int output = subgraph->outputs()[i];
- TfLiteTensor* tensor = subgraph->tensor(output);
- if (tensor->type == kTfLiteBool)
- {
- uint8_t* temp_tensor = new uint8_t[tensor->bytes / sizeof(bool)];
- output_boolean_tensors[i] = temp_tensor;
- }
- }
-
- // Currently perform deep copy of input buffer
- for (size_t i = 0; i < subgraph->inputs().size(); i++) {
- int input = subgraph->inputs()[i];
- // TODO(aselle): Is this what we want or do we want input instead?
- // TODO(aselle): This should be called setInputValue maybe to be cons.
- TfLiteTensor* tensor = subgraph->tensor(input);
- // Workaround to pass bool type under NNAPI
- // ANEURALNETWORKS_TENSOR_BOOL8 tensor element size is 8 bits
- if (tensor->type == kTfLiteBool)
- {
- CHECK_NN(ANeuralNetworksExecution_setInput(
- execution, i, nullptr, input_boolean_tensors[i], tensor->bytes * sizeof(uint8_t) / sizeof(bool)));
- }
- else
- {
- CHECK_NN(ANeuralNetworksExecution_setInput(
- execution, i, nullptr, tensor->data.raw, tensor->bytes));
- }
- }
-
- // Tell nn api where to place final data.
- for (size_t i = 0; i < subgraph->outputs().size(); i++) {
- int output = subgraph->outputs()[i];
- TfLiteTensor* tensor = subgraph->tensor(output);
-
- // Workaround to pass bool type under NNAPI
- // ANEURALNETWORKS_TENSOR_BOOL8 tensor element size is 8 bits
- if (tensor->type == kTfLiteBool)
- {
- CHECK_NN(ANeuralNetworksExecution_setOutput(
- execution, i, nullptr, output_boolean_tensors[i], tensor->bytes * sizeof(uint8_t) / sizeof(bool)));
- }
- else
- {
- CHECK_NN(ANeuralNetworksExecution_setOutput(
- execution, i, nullptr, tensor->data.raw, tensor->bytes));
- }
- }
-
- // The state_out of previous invocation need to be mapped to state_in of
- // current invocation.
- for (size_t i = 0; i < model_states_outputs_.size(); i++) {
- int state_tensor_idx = model_states_outputs_[i];
- TfLiteTensor* tensor = subgraph->tensor(state_tensor_idx);
- // Here we are using a deep copy for state_in tensors so that we are not
- // reading and writing into the same buffer during a invocation.
- // TODO(miaowang): using double shared buffer to minimize the copies.
- CHECK_NN(ANeuralNetworksExecution_setInput(
- execution, i + subgraph->inputs().size(), nullptr, tensor->data.raw,
- tensor->bytes));
- // Tell NNAPI where to output the state_out.
- CHECK_NN(ANeuralNetworksExecution_setOutput(
- execution, i + subgraph->outputs().size(), nullptr, tensor->data.raw,
- tensor->bytes));
- }
-
- // Currently use blocking compute.
- ANeuralNetworksEvent* event = nullptr;
- CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event));
- CHECK_NN(ANeuralNetworksEvent_wait(event));
- ANeuralNetworksEvent_free(event);
- ANeuralNetworksExecution_free(execution);
-
- // Tell nn api where to place final data.
- for (size_t i = 0; i < subgraph->inputs().size(); i++) {
- int input = subgraph->inputs()[i];
- TfLiteTensor* tensor = subgraph->tensor(input);
-
- if (tensor->type == kTfLiteBool)
- {
- uint8_t* temp_tensor = input_boolean_tensors[i];
- input_boolean_tensors[i] = nullptr;
- delete temp_tensor;
- }
- }
- for (size_t i = 0; i < subgraph->outputs().size(); i++) {
- int output = subgraph->outputs()[i];
- TfLiteTensor* tensor = subgraph->tensor(output);
-
- if (tensor->type == kTfLiteBool)
- {
- uint8_t* temp_tensor = output_boolean_tensors[i];
- size_t elements = tensor->bytes / sizeof(bool);
- for (size_t idx = 0; idx < elements; idx++)
- {
- tensor->data.b[idx] = ((temp_tensor[idx] == 0x00) ? false : true);
- }
- output_boolean_tensors[i] = nullptr;
- delete temp_tensor;
- }
- }
-
-#if 0
- printf("From the NN API:\n");
- TfLiteTensor* tensor = subgraph->tensor(subgraph->outputs()[0]);
- if (float* data =
- subgraph->typed_tensor<float>(subgraph->outputs()[0])) {
- size_t num = tensor->bytes / sizeof(float);
- for (float* p = data; p < data + num; p++) {
- printf(" %f", *p);
- }
- printf("\n");
- }
-#endif
-
- return kTfLiteOk;
-}
-
-bool NNAPIDelegate::IsSupported() { return nnfw::NNAPIExists(); }
-
-} // namespace tflite
-} // namespace nnfw
-
-// clang-format on
+++ /dev/null
-// This file is included from AddOpsAndParams defined in nnapi_delegate.cc
-// and contains lambda for extened implementation to original Tensorflow Lite.
- auto add_scalar_bool8 = [&nn_model, &augmented_inputs,
- &next_id](bool value) {
- // Fix to use strict build option
- int8_t casted_value = (value ? 1 : 0);
- ANeuralNetworksOperandType operand_type{}; operand_type.type = ANEURALNETWORKS_BOOL;
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
- CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &casted_value,
- sizeof(int8_t)))
- augmented_inputs.push_back(next_id++);
- };
-
- auto add_resize_bilinear_params = [&add_scalar_int32, &subgraph, &augmented_inputs](void* data) {
- auto builtin = reinterpret_cast<TfLiteResizeBilinearParams*>(data);
- if (builtin->align_corners) {
- FATAL("Resize bilinear does not support align corners in NNAPI");
- }
-
- TfLiteTensor* tensor = subgraph->tensor(augmented_inputs.back());
- assert(tensor->type == kTfLiteInt32);
- assert(tensor->bytes == sizeof(int)*2);
- augmented_inputs.pop_back();
-
- int height = ((int*)(tensor->data.raw))[1];
- int width = ((int*)(tensor->data.raw))[0];
- add_scalar_int32(height);
- add_scalar_int32(width);
- };
-
- auto add_transpose_conv_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteTransposeConvParams*>(data);
- add_scalar_int32(builtin->padding);
- add_scalar_int32(builtin->stride_width);
- add_scalar_int32(builtin->stride_height);
- };
-
- auto add_lrn_params = [&add_scalar_int32,
- &add_scalar_float32](void* data) {
- auto builtin = reinterpret_cast<TfLiteLocalResponseNormParams*>(data);
- add_scalar_int32(builtin->radius);
- add_scalar_float32(builtin->bias);
- add_scalar_float32(builtin->alpha);
- add_scalar_float32(builtin->beta);
- };
-
- auto add_strided_slice_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteStridedSliceParams*>(data);
- add_scalar_int32(builtin->begin_mask);
- add_scalar_int32(builtin->end_mask);
- // ellipsis_mask and new_axis_mask are not supported on nn runtime
- // cf) tflite interpreter supports both operations
- if (builtin->ellipsis_mask) {
- FATAL("STRIDE_SLICE does not support ellipsis_mask in NNAPI");
- }
- if (builtin->new_axis_mask) {
- FATAL("STRIDE_SLICE does not support new_axis_mask in NNAPI");
- }
- add_scalar_int32(builtin->shrink_axis_mask);
- };
-
- auto add_gather_params = [&add_scalar_int32, &augmented_inputs](void* data) {
- auto builtin = reinterpret_cast<TfLiteGatherParams*>(data);
- if (builtin->axis != 0) {
- FATAL("GATHER does not support axis>0 in NNAPI");
- }
-
- auto indices_index = augmented_inputs.back();
- augmented_inputs.pop_back();
- add_scalar_int32(builtin->axis);
- augmented_inputs.push_back(indices_index);
- };
-
- auto add_pack_ex_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLitePackParams*>(data);
- add_scalar_int32(builtin->values_count);
- add_scalar_int32(builtin->axis);
- };
-
- auto add_unpack_ex_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteUnpackParams*>(data);
- add_scalar_int32(builtin->num);
- add_scalar_int32(builtin->axis);
- };
-
- auto check_batch_to_space_params = [subgraph, &node, &augmented_inputs]() {
-
- //If there are 3 inputs, check if crops is having default values {0, 0, 0, 0}
- //Else unsupported by NNAPI
-
- if(augmented_inputs.size() == 3)
- {
- const uint32_t crops_buffer_index = node.inputs->data[2];
- const TfLiteTensor* crops = subgraph->tensor(crops_buffer_index);
- const int *crops_value = crops->data.i32;
-
- //Check if crops is having default values {0, 0, 0, 0}
- if(crops_value[0] != 0 || crops_value[1] != 0 || crops_value[2] != 0 || crops_value[3] != 0)
- {
- FATAL("BATCH_TO_SPACE_ND does not support Explicit crops in NNAPI");
- }
- else
- {
- //Restrict crops input and pass only other two inputs
- augmented_inputs.pop_back();
- }
- }
- };
-
- auto add_split_params = [&add_scalar_int32, &augmented_inputs](void* data) {
- // swap 1st and 2nd operand order
- auto input_tensor = augmented_inputs[1];
- auto axis = augmented_inputs[0];
- augmented_inputs[0] = input_tensor;
- augmented_inputs[1] = axis;
-
- auto builtin = reinterpret_cast<TfLiteSplitParams*>(data);
- add_scalar_int32(builtin->num_splits);
- };
-
- auto check_arg_max_input = [&subgraph, &augmented_inputs](void *data) {
- auto params = reinterpret_cast<TfLiteArgMaxParams*>(data);
- if (params->output_type != kTfLiteInt32)
- {
- FATAL("Cannot handle output type in NNAPI");
- }
-
- TfLiteTensor* axis_tensor = subgraph->tensor(augmented_inputs.back());
- assert(axis_tensor->type == kTfLiteInt32);
-
- int64_t count = 1;
- for (int i = 0; i < axis_tensor->dims->size; ++i) {
- count *= axis_tensor->dims->data[i];
- }
- assert(count == 1);
- };
-
- auto add_reducer_params = [&add_scalar_bool8](void* data) {
- auto builtin = reinterpret_cast<TfLiteReducerParams*>(data);
- if (builtin == nullptr)
- {
- add_scalar_bool8(0);
- }
- else
- {
- add_scalar_bool8(builtin->keep_dims);
- }
- };
-
- auto add_one_hot_params = [&add_scalar_int32](void* data) {
- const auto* builtin = reinterpret_cast<TfLiteOneHotParams*>(data);
- add_scalar_int32(builtin->axis);
- };
+++ /dev/null
-# We may need to support multiple tensorflow version
-# Example)
-# For ubuntu: tensorflow lite v1.13.1
-# For tizen: tensorflow lite v1.9
-set(SUPPORT_TFLITE_VERSION "1.13.1" CACHE STRING "Supporting TensorFlow lite version")
-
-add_subdirectories()
#include "tflite/RandomTestRunner.h"
#include "tflite/Diff.h"
#include "tflite/TensorLogger.h"
-#include "tflite/ext/nnapi_delegate.h"
#include <misc/tensor/IndexIterator.h>
#include <misc/tensor/Object.h>
std::cout << "[NNAPI TEST] Run T/F Lite Interpreter without NNAPI" << std::endl;
_tfl_interp->Invoke();
- nnfw::tflite::NNAPIDelegate d;
-
for (size_t i = 1; i <= running_count; ++i)
{
resetter.run(*(_nnapi.get()));
std::cout << "[NNAPI TEST #" << i << "] Run T/F Lite Interpreter with NNAPI" << std::endl;
- char *env = getenv("UPSTREAM_DELEGATE");
-
- if (env && !std::string(env).compare("1"))
- {
- _nnapi->Invoke();
- }
- else
+ if (_nnapi->Invoke() != kTfLiteOk)
{
- // WARNING
- // primary_subgraph: Experimental interface. Return 1st sugbraph
- // Invoke() will call BuildGraph() internally
- if (d.Invoke(&_nnapi.get()->primary_subgraph()))
- {
- throw std::runtime_error{"Failed to BuildGraph"};
- }
+ throw std::runtime_error{"Failed to Run T/F Lite Interpreter with NNAPI"};
}
// Compare OFM
#include "tflite/interp/FlatBufferBuilder.h"
-#include "tflite/ext/kernels/register.h"
+#include <tensorflow/lite/kernels/register.h>
namespace nnfw
{
{
std::unique_ptr<::tflite::Interpreter> interpreter;
- nnfw::tflite::BuiltinOpResolver resolver;
+ ::tflite::ops::builtin::BuiltinOpResolver resolver;
::tflite::InterpreterBuilder builder(_model, resolver);
*/
NNFW_STATUS nnfw_output_tensorindex(nnfw_session *session, const char *tensorname, uint32_t *index);
+/**
+ * @brief Set the backend for each operation in the session
+ *
+ * This function assigns backends (acl_cl, acl_neon, cpu) to each operation in the session.
+ * If successful, the function returns @c NNFW_STATUS_NO_ERROR. Otherwise, the function returns
+ * @c NNFW_STATUS_ERROR.
+ *
+ * @note The argument specifying backends must be in the format
+ * "OP_BACKEND_MAP=\"0=acl_cl;1=cpu;2=acl_cl\"".
+ *
+ * @param[in] session the session object
+ * @param[in] backend_settings String containing backend assignments indexed by operation sequence
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_set_backends_per_operation(nnfw_session *session, const char *backend_settings);
+
+/**
+ * @brief Prepare session to be ready for inference
+ *
+ * This phase may finalize model compilation, scheduling, and additional settings.
+ *
+ * @param[in] session the session to be prepared
+ * @param[in] map_file_path Optional path to a partition map file
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_prepare_pipeline(nnfw_session *session, const char *map_file_path = nullptr);
+
+/**
+ * @brief Set input buffer
+ *
+ * This function must be called after {@link nnfw_prepare_pipeline}. \p inputs given to this
+ * function can be reused for many inferences. \p lengths must be greater than or equal to what
+ * the operands require. If you give an empty \p inputs to this function, it will join all threads.
+ *
+ * @param[in] session Session to which the input is to be set
+ * @param[in] inputs Raw buffers for input; it must be a pointer to \p std::vector<void *> for a
+ * multiple-input model
+ * @param[in] lengths Sizes in bytes of the input buffers; it must be a pointer to
+ * \p std::vector<uint32_t> for a multiple-input model
+ *
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_push_pipeline_input(nnfw_session *session, void *inputs, void *lengths);
+
+/**
+ * @brief Get last outputs of partitioned model in session
+ *
+ * This function must be called after {@link nnfw_prepare_pipeline}. The \p outputs given to this
+ * function must be cleared for memory management.
+ *
+ * @param[in] session Session from which the last outputs are to be extracted
+ * @param[out] outputs Raw buffers for outputs; it must be a pointer to \p std::vector<void *> for
+ * a multiple-output model
+ *
+ * @return @c NNFW_STATUS_NO_ERROR if successful
+ */
+NNFW_STATUS nnfw_pop_pipeline_output(nnfw_session *session, void *outputs);
+
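+/*
+ * A minimal usage sketch for the pipeline API above (illustrative only; the partition map file
+ * name and the buffer variables are hypothetical, and error handling is omitted):
+ *
+ *   nnfw_session *session = nullptr;
+ *   nnfw_create_session(&session);
+ *   nnfw_load_model_from_file(session, "model.nnpkg");
+ *   nnfw_set_backends_per_operation(session, "OP_BACKEND_MAP=\"0=acl_cl;1=cpu\"");
+ *   nnfw_prepare_pipeline(session, "partition_map.json");
+ *
+ *   std::vector<void *> inputs{input_buffer};
+ *   std::vector<uint32_t> lengths{input_size_in_bytes};
+ *   nnfw_push_pipeline_input(session, &inputs, &lengths);
+ *
+ *   std::vector<void *> outputs;
+ *   nnfw_pop_pipeline_output(session, &outputs);
+ *
+ *   std::vector<void *> empty_inputs;        // pushing an empty input set joins the worker threads
+ *   std::vector<uint32_t> empty_lengths;
+ *   nnfw_push_pipeline_input(session, &empty_inputs, &empty_lengths);
+ */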
#endif // __NNFW_EXPERIMENTAL_H__
* NNFW_VERSION is a uint32 value representing nnfw runtime version
* in 0xMMmmmmPP, where MM = major, mmmm = minor, PP = patch
*/
-#define NNFW_VERSION 0x01000f00
+#define NNFW_VERSION 0x01001100
#endif // __NNFW_VERSION_H__
NNFW_RETURN_ERROR_IF_NULL(session);
return session->output_tensorindex(tensorname, index);
}
+
+NNFW_STATUS nnfw_set_backends_per_operation(nnfw_session *session, const char *backend_settings)
+{
+ NNFW_RETURN_ERROR_IF_NULL(session);
+ return session->set_backends_per_operation(backend_settings);
+}
+
+NNFW_STATUS nnfw_prepare_pipeline(nnfw_session *session, const char *map_file_path)
+{
+ NNFW_RETURN_ERROR_IF_NULL(session);
+ return session->prepare_pipeline(map_file_path);
+}
+
+NNFW_STATUS nnfw_push_pipeline_input(nnfw_session *session, void *inputs, void *lengths)
+{
+ NNFW_RETURN_ERROR_IF_NULL(session);
+ return session->push_pipeline_input((std::vector<void *> *)inputs,
+ (std::vector<uint32_t> *)lengths);
+}
+
+NNFW_STATUS nnfw_pop_pipeline_output(nnfw_session *session, void *outputs)
+{
+ NNFW_RETURN_ERROR_IF_NULL(session);
+ return session->pop_pipeline_output((std::vector<void *> *)outputs);
+}
} // namespace
nnfw_session::nnfw_session()
- : _subgraphs{nullptr}, _execution{nullptr},
+ : _subgraphs{nullptr}, _compiler{nullptr}, _execution{nullptr},
_kernel_registry{std::make_shared<onert::api::CustomKernelRegistry>()}, _tracing_ctx{nullptr}
{
// DO NOTHING
std::string manifest_file_name = package_path + "/metadata/MANIFEST";
std::ifstream mfs(manifest_file_name);
+ _package_file_path = package_path;
// extract the filename of the first(index 0) model
// e.g. In MANIFEST file, { "models" : [ "firstmodel.tflite", "2nd.tflite" ] }
Json::Value root;
return NNFW_STATUS_NO_ERROR;
}
+NNFW_STATUS nnfw_session::prepare_pipeline(const char *map_file_path)
+{
+ // NOTE If users want to run prepare_pipeline() more than once, this check could be removed.
+ if (!isStateModelLoaded())
+ {
+ std::cerr << "Error during model prepare pipeline : ";
+ if (isStateInitialized())
+ {
+ std::cerr << "prepare_pipeline should be run once";
+ }
+ else
+ {
+ std::cerr << "invalid state";
+ }
+ std::cerr << std::endl;
+ return NNFW_STATUS_INVALID_STATE;
+ }
+
+ try
+ {
+ _subgraphs.reset();
+ std::vector<std::shared_ptr<onert::exec::ExecutorMap>> executor_maps =
+ _compiler->compile(_package_file_path.c_str(), map_file_path);
+
+ for (auto it = executor_maps.begin(); it != executor_maps.end(); ++it)
+ {
+ _executions.push_back(std::make_shared<onert::exec::Execution>(*it));
+ }
+ make_dependency();
+ _threads.resize(_executions.size());
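+ // Launch one worker thread per partitioned execution; the threads run until an empty input
+ // set is pushed via push_pipeline_input(), which joins them.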
+ for (uint32_t i = 0; i < _threads.size(); i++)
+ {
+ _threads[i] = std::thread(&onert::exec::Execution::runInference, _executions[i].get());
+ }
+ }
+ catch (const std::exception &e)
+ {
+ std::cerr << "Error during model prepare : " << e.what() << std::endl;
+ return NNFW_STATUS_ERROR;
+ }
+
+ _state = State::PREPARED;
+ return NNFW_STATUS_NO_ERROR;
+}
+
NNFW_STATUS nnfw_session::run()
{
if (!isStatePreparedOrFinishedRun())
return NNFW_STATUS_INVALID_STATE;
}
+ if (!_executions.empty())
+ {
+ std::cerr << "Error during nnfw_session::run : not supported for pipeline run" << std::endl;
+ return NNFW_STATUS_ERROR;
+ }
+
try
{
_execution->execute();
return NNFW_STATUS_INVALID_STATE;
}
+ if (!_executions.empty())
+ {
+ std::cerr << "Error during nnfw_session::run_async : not supported for pipeline run"
+ << std::endl;
+ return NNFW_STATUS_ERROR;
+ }
+
_execution->startExecute();
_state = State::RUNNING;
return NNFW_STATUS_ERROR;
}
+ if (!_executions.empty())
+ {
+ std::cerr << "Error during nnfw_session::await : not supported for pipeline run" << std::endl;
+ return NNFW_STATUS_ERROR;
+ }
+
_execution->waitFinish();
_state = State::FINISHED_RUN;
return NNFW_STATUS_ERROR;
}
+ if (!_executions.empty())
+ {
+ std::cerr << "Error during nnfw_session::set_input : not supported for pipeline run"
+ << std::endl;
+ return NNFW_STATUS_ERROR;
+ }
+
try
{
_execution->setInput(onert::ir::IOIndex(index), buffer, length);
return NNFW_STATUS_ERROR;
}
+ if (!_executions.empty())
+ {
+ std::cerr << "Error during nnfw_session::set_output : not supported for pipeline run"
+ << std::endl;
+ return NNFW_STATUS_ERROR;
+ }
+
try
{
_execution->setOutput(onert::ir::IOIndex(index), buffer, length);
std::cerr << "Error during nnfw_session::set_input_layout, not supported layout" << std::endl;
return NNFW_STATUS_ERROR;
}
- _execution->setInputLayout(onert::ir::IOIndex(index), convertLayout(layout));
+ if (_execution)
+ {
+ _execution->setInputLayout(onert::ir::IOIndex(index), convertLayout(layout));
+ }
+ else
+ {
+ _executions.at(0)->setInputLayout(onert::ir::IOIndex(index), convertLayout(layout));
+ }
}
catch (const std::exception &e)
{
<< std::endl;
return NNFW_STATUS_ERROR;
}
- _execution->setOutputLayout(onert::ir::IOIndex(index), convertLayout(layout));
+ if (_execution)
+ {
+ _execution->setOutputLayout(onert::ir::IOIndex(index), convertLayout(layout));
+ }
+ else
+ {
+ _executions.at(_executions.size() - 1)
+ ->setOutputLayout(onert::ir::IOIndex(index), convertLayout(layout));
+ }
}
catch (const std::exception &e)
{
}
else // when called after nnfw_session::prepare()
{
- _execution->changeInputShape(onert::ir::IOIndex(index), new_shape);
+ if (_execution)
+ {
+ _execution->changeInputShape(onert::ir::IOIndex(index), new_shape);
+ }
+ else
+ {
+ _executions.at(0)->changeInputShape(onert::ir::IOIndex(index), new_shape);
+ }
}
return NNFW_STATUS_NO_ERROR;
auto opidx = primary_subgraph()->getInputs().at(index);
auto shape = primary_subgraph()->operands().at(opidx).shape();
if (isStatePreparedOrFinishedRun())
- shape = _execution->getInputShape(onert::ir::IOIndex{index});
+ {
+ if (_execution)
+ {
+ shape = _execution->getInputShape(onert::ir::IOIndex{index});
+ }
+ else
+ {
+ shape = _executions.at(0)->getInputShape(onert::ir::IOIndex{index});
+ }
+ }
+
ti->rank = shape.rank();
for (int j = 0; j < ti->rank; ++j)
{
auto shape = primary_subgraph()->operands().at(opidx).shape();
// If it is called after `nnfw_run` then get the shape from Execution, not from the graph
if (isStateFinishedRun())
- shape = _execution->getOutputShape(onert::ir::IOIndex{index});
+ {
+ if (_execution)
+ {
+ shape = _execution->getOutputShape(onert::ir::IOIndex{index});
+ }
+ else
+ {
+ shape = _executions.at(_executions.size() - 1)->getOutputShape(onert::ir::IOIndex{index});
+ }
+ }
ti->rank = shape.rank();
for (int j = 0; j < ti->rank; ++j)
{
return NNFW_STATUS_NO_ERROR;
}
+
+void nnfw_session::make_dependency()
+{
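+ // For every pair of partitioned executions, match an output of one against an input of another
+ // by tensor name and shape, and register the connection via pushNextExe() so intermediate
+ // results can be passed between pipeline stages.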
+ for (uint32_t out_exe = 0; out_exe < _executions.size(); out_exe++)
+ {
+ auto out_graph = _executions[out_exe]->primary_subgraph();
+ for (uint32_t in_exe = 0; in_exe < _executions.size(); in_exe++)
+ {
+ if (out_exe == in_exe)
+ continue;
+ auto in_graph = _executions[in_exe]->primary_subgraph();
+ for (auto out = out_graph._name_to_output_begin(); out != out_graph._name_to_output_end();
+ out++)
+ {
+ auto out_opidx = out_graph.getOutputs().at(out->second);
+ auto out_shape = out_graph.operands().at(out_opidx).shape();
+ for (auto in = in_graph._name_to_input_begin(); in != in_graph._name_to_input_end(); in++)
+ {
+ if (out->first != in->first)
+ continue;
+
+ auto in_opidx = in_graph.getInputs().at(in->second);
+ auto in_shape = in_graph.operands().at(in_opidx).shape();
+ if (out_shape.rank() != in_shape.rank())
+ continue;
+
+ bool is_same = true;
+ for (int32_t i = 0; i < out_shape.rank(); i++)
+ {
+ if (out_shape.dim(i) != in_shape.dim(i))
+ {
+ is_same = false;
+ break;
+ }
+ }
+
+ if (is_same)
+ _executions[out_exe]->pushNextExe(_executions[in_exe], out->second, in->second);
+ }
+ }
+ }
+ }
+}
+
+NNFW_STATUS nnfw_session::push_pipeline_input(std::vector<void *> *inputs,
+ std::vector<uint32_t> *lengths)
+{
+ static uint32_t count = 0;
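+ // An empty input set is the termination signal: mark the first execution finished and join
+ // all pipeline worker threads.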
+ if (inputs->empty())
+ {
+ _executions[0]->setFinish();
+ for (uint32_t i = 0; i < _threads.size(); i++)
+ {
+ _threads[i].join();
+ }
+ return NNFW_STATUS_NO_ERROR;
+ }
+ _executions[0]->asyncIoDescSemWait();
+ _executions[0]->createNewAsyncDesc(count++);
+ for (uint32_t i = 0; i < inputs->size(); i++)
+ {
+ _executions[0]->executeAsyncInput(onert::ir::IOIndex(i), inputs->at(i), lengths->at(i));
+ }
+ _executions[0]->asyncIoDescSemPost();
+ return NNFW_STATUS_NO_ERROR;
+}
+
+NNFW_STATUS nnfw_session::pop_pipeline_output(std::vector<void *> *outputs)
+{
+ auto results = _executions[_executions.size() - 1]->getAsyncResults();
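+ // Spin until the last pipeline stage produces a result; stopWait() reports that the pipeline
+ // was stopped before any result became available.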
+ while (results->empty())
+ {
+ if (_executions[_executions.size() - 1]->stopWait())
+ return NNFW_STATUS_ERROR;
+ }
+
+ auto result = results->front();
+ results->pop_front();
+ for (uint32_t i = 0; i < result.size(); i++)
+ outputs->push_back(result[i]);
+ return NNFW_STATUS_NO_ERROR;
+}
+
NNFW_STATUS nnfw_session::register_custom_operation(const std::string &id,
nnfw_custom_eval eval_func)
{
{
if (_subgraphs)
{
- assert(!_execution);
+ assert(!_execution && _executions.empty());
return _subgraphs->primary().get();
}
else
{
- assert(_execution);
+ assert(_execution || !_executions.empty());
// TODO Remove const_cast
// We assumed the graph will not change after compilation, but shape could change
+ if (!_executions.empty())
+ {
+ return &_executions[0]->primary_parentgraph();
+ }
+
return &_execution->primary_subgraph();
}
}
{
assert(!_subgraphs);
assert(!_compiler);
- assert(!_execution);
+ assert(!_execution && _executions.empty());
return true;
}
else
{
assert(_subgraphs);
assert(_compiler);
- assert(!_execution);
+ assert(!_execution && _executions.empty());
return true;
}
else
{
assert(!_subgraphs);
assert(_compiler);
- assert(_execution);
+ assert(_execution || !_executions.empty());
return true;
}
else
{
assert(!_subgraphs);
assert(_compiler);
- assert(_execution);
+ assert(_execution || !_executions.empty());
return true;
}
return false;
{
assert(!_subgraphs);
assert(_compiler);
- assert(_execution);
+ assert(_execution || !_executions.empty());
return true;
}
else
{
return getTensorIndexImpl(*primary_subgraph(), tensorname, index, false);
}
+
+NNFW_STATUS nnfw_session::set_backends_per_operation(const char *backend_settings)
+{
+ if (backend_settings == nullptr)
+ {
+ return NNFW_STATUS_ERROR;
+ }
+ _compiler->set_backend_from_str(backend_settings);
+ return NNFW_STATUS_NO_ERROR;
+}
#include <string>
#include <memory>
+#include <thread>
+#include <vector>
namespace onert
{
NNFW_STATUS load_model_from_nnpackage(const char *package_file_path);
NNFW_STATUS prepare();
+ NNFW_STATUS prepare_pipeline(const char *map_file_path);
NNFW_STATUS run();
NNFW_STATUS run_async();
NNFW_STATUS set_available_backends(const char *backends);
NNFW_STATUS set_op_backend(const char *op, const char *backend);
+ // accessor
+ std::vector<std::shared_ptr<onert::exec::Execution>> *get_executions() { return &_executions; }
+
//
// Internal-only API
//
//
// Experimental API
//
+ void make_dependency();
+ NNFW_STATUS push_pipeline_input(std::vector<void *> *inputs, std::vector<uint32_t> *lengths);
+ NNFW_STATUS pop_pipeline_output(std::vector<void *> *outputs);
NNFW_STATUS register_custom_operation(const std::string &id, nnfw_custom_eval eval_func);
NNFW_STATUS input_tensorindex(const char *tensorname, uint32_t *index);
NNFW_STATUS output_tensorindex(const char *tensorname, uint32_t *index);
+ NNFW_STATUS set_backends_per_operation(const char *backend_settings);
private:
const onert::ir::Graph *primary_subgraph();
std::unique_ptr<onert::compiler::Compiler> _compiler;
std::unique_ptr<onert::exec::Execution> _execution;
std::shared_ptr<onert::api::CustomKernelRegistry> _kernel_registry;
+ std::vector<std::thread> _threads;
+ std::vector<std::shared_ptr<onert::exec::Execution>> _executions;
+ std::string _package_file_path;
std::unique_ptr<onert::util::TracingCtx> _tracing_ctx;
};
add_subdirectory(acl_neon)
add_subdirectory(acl_common)
add_subdirectory(ruy)
+add_subdirectory(gpu_cl)
add_subdirectory(xnnpack)
{
switch (type_ir)
{
+ case ir::operation::ElementwiseBinary::ElementwiseBinaryType::FLOOR_DIV:
+ return ops::ElementwiseBinaryType::kFloorDiv;
case ir::operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND:
return ops::ElementwiseBinaryType::kLogicalAnd;
case ir::operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR:
#include "OperationUtils.h"
+#include <cker/operation/FloorDiv.h>
#include <cker/operation/LogicalAnd.h>
#include <cker/operation/LogicalOr.h>
#include <cker/operation/MaxMin.h>
namespace
{
template <typename T>
+void FloorDivGeneric(const IPortableTensor *lhs, const IPortableTensor *rhs,
+ IPortableTensor *output)
+{
+ if (!HaveSameShapes(lhs, rhs))
+ {
+ nnfw::cker::FloorDivBroadcast<T>(getShape(lhs), getBuffer<T>(lhs), getShape(rhs),
+ getBuffer<T>(rhs), getShape(output), getBuffer<T>(output));
+ }
+ else
+ {
+ nnfw::cker::FloorDivElementwise<T>(getShape(lhs), getBuffer<T>(lhs), getBuffer<T>(rhs),
+ getBuffer<T>(output));
+ }
+}
+
+template <typename T>
void logicalAndGeneric(const IPortableTensor *lhs, const IPortableTensor *rhs,
IPortableTensor *output)
{
switch (op_type)
{
+ case ElementwiseBinaryType::kFloorDiv:
+ if (_lhs->data_type() == OperandType::FLOAT32)
+ {
+ _kernel = FloorDivGeneric<float>;
+ }
+ else if (_lhs->data_type() == OperandType::INT32)
+ {
+ _kernel = FloorDivGeneric<int32_t>;
+ }
+ else
+ {
+ throw std::runtime_error{"Max: unsupported data type"};
+ }
+ break;
case ElementwiseBinaryType::kLogicalAnd:
if ((_lhs->data_type() == OperandType::BOOL8) && (_rhs->data_type() == OperandType::BOOL8))
{
enum class ElementwiseBinaryType
{
+ kFloorDiv,
kLogicalAnd,
kLogicalOr,
kMax,
op_params.filter_height = kernelHeight; \
op_params.filter_width = kernelWidth; \
op_params.padding_values.height = (int8_t)paddingTop; \
- op_params.padding_values.width = (int8_t)paddingLeft;
+ op_params.padding_values.width = (int8_t)paddingLeft; \
+ op_params.float_activation_min = 0; \
+ op_params.float_activation_max = 0; \
+ op_params.quantized_activation_min = 0; \
+ op_params.quantized_activation_max = 0;
void PoolLayer::configure(const IPortableTensor *input, const uint32_t paddingLeft, const uint32_t,
const uint32_t paddingTop, const uint32_t, const uint32_t strideWidth,
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_BACKEND_H__
+#define __ONERT_BACKEND_GPU_CL_BACKEND_H__
+
+#include <backend/Backend.h>
+#include <memory>
+
+#include "BackendContext.h"
+#include "Config.h"
+#include "ClTensorRegistry.h"
+#include "KernelGenerator.h"
+#include "TensorManager.h"
+#include "TensorBuilder.h"
+
+#include "open_cl/Environment.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class Backend : public ::onert::backend::Backend
+{
+public:
+ Backend() : _config{std::make_shared<Config>()} {}
+
+ std::shared_ptr<IConfig> config() const override { return _config; }
+
+ std::unique_ptr<onert::backend::BackendContext> newContext(ContextData &&data) const override
+ {
+ const auto &graph = *data.graph;
+ const auto &operands = data.graph->operands();
+ auto context = std::make_unique<gpu_cl::BackendContext>(this, std::move(data));
+
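+    // Creating the OpenCL environment can fail (e.g. no usable OpenCL runtime on the device);
+    // in that case report failure by returning a null context.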
+ auto environment = std::make_shared<Environment>();
+ if (!CreateEnvironment(environment.get()).ok())
+ {
+ return nullptr;
+ }
+ auto tm = createTensorManager(&environment->context());
+
+ auto tr = std::make_shared<ClTensorRegistry<TensorManager>>(tm);
+
+ InferenceContext::CreateInferenceInfo create_info;
+ create_info.precision = CalculationsPrecision::F32;
+ create_info.storage_type =
+ GetStorageTypeWithMinimalMemoryConsumption(environment->device().GetInfo());
+ create_info.hints.Add(ModelHints::kFastestInference);
+
+ auto cc = std::make_shared<CreationContext>();
+ cc->device = environment->GetDevicePtr();
+ cc->context = &environment->context();
+ cc->queue = environment->queue();
+ cc->cache = environment->program_cache();
+
+ auto tb = std::make_shared<TensorBuilder>(operands, tm, create_info, environment);
+ context->tensor_registry = tr;
+ context->tensor_builder = tb;
+
+ context->kernel_gen = std::make_shared<KernelGenerator>(graph, tb, tr, cc);
+ context->constant_initializer = std::make_shared<ConstantInitializer>(operands, tr);
+ return context;
+ }
+
+private:
+ std::shared_ptr<IConfig> _config;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_BACKEND_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BackendContext.h"
+
+#include "ConstantInitializer.h"
+#include "TensorBuilder.h"
+#include "KernelGenerator.h"
+
+#include "util/logging.h"
+#include "ir/Index.h"
+#include "ir/Operations.h"
+#include "ir/OperandIndexMap.h"
+#include "ir/OperandIndexSequence.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+void BackendContext::initConsts()
+{
+ _data.graph->operations().iterate([&](const ir::OperationIndex &, const ir::Operation &op) {
+ constant_initializer->setLayout(graph()->layout());
+ op.accept(*constant_initializer);
+ });
+ _data.graph->operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &operand) {
+ if (_data.external_operands.contains(ind) || !operand.isConstant())
+ return;
+ const auto &obj = graph()->operands().at(ind);
+ if (obj.isConstant() && !constant_initializer->exist(ind))
+ {
+ constant_initializer->registerDefaultInitializer(ind, obj);
+ }
+ });
+
+ constant_initializer->run();
+}
+
+void BackendContext::planTensors()
+{
+ ir::OperandIndexMap<uint32_t> uses_map;
+ ir::OperandIndexMap<uint32_t> def_map;
+ ir::OperandIndexSequence constants;
+
+ // Prepare scanning
+ _data.graph->operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
+ if (_data.external_operands.contains(ind))
+ return;
+ uses_map[ind] = obj.getUses().size();
+ def_map[ind] = obj.getDef().valid() ? 1 : 0;
+
+ if (obj.isConstant())
+ constants.append(ind);
+
+ if (!tensor_builder->isRegistered(ind))
+ {
+ // These tensors do not exist in any operation (No use and def)
+ const auto info = obj.info();
+ const auto layout = _data.operand_layouts.at(ind);
+ // TODO Change tensor info to have permuted shape
+ tensor_builder->registerTensorInfo(ind, info, layout);
+ }
+ });
+
+ // Start scanning to do notify{First|Last}Use for each tensor
+
+  // If a tensor is a constant, increase its use count and allocate it first.
+  // Increasing the use count here means the tensor is never deallocated during execution,
+  // i.e., it will be deallocated last.
+ VERBOSE(planTensors) << "TENSORS as CONSTANT" << std::endl;
+ for (const auto &ind : constants)
+ {
+ uses_map[ind]++;
+ tensor_builder->notifyFirstUse(ind);
+ }
+
+ // At each operation,
+  // 1. Scan DEF of outputs. If this operation defines the output, allocate it
+  // 2. Scan DEF of inputs. If an input is a variable tensor, allocate it
+ // 3. Scan USE of inputs. Decrease the USE and deallocate if the USE is 0
+ for (const auto op_ind : _data.op_order)
+ {
+ const auto &op = graph()->operations().at(op_ind);
+ auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
+ auto op_outputs = op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
+
+ // Define outputs
+ for (const auto &ind : op_outputs)
+ {
+ if (!tensor_builder->isRegistered(ind))
+ continue;
+ assert(def_map.find(ind) != def_map.end());
+ if (def_map[ind])
+ {
+ def_map[ind] = 0;
+ tensor_builder->notifyFirstUse(ind);
+ }
+ }
+
+ // Scan variable tensors
+    // These tensors behave like constants, but OperandInfo and LowerInfo treat them as
+    // non-constant so that the memory planner here can keep memory usage low
+ for (const auto &ind : op_inputs)
+ {
+ if (!tensor_builder->isRegistered(ind))
+ continue;
+ const auto &operand = graph()->operands().at(ind);
+ if (operand.info().isVariable())
+ {
+        // A variable tensor that has a buffer is not supported yet
+ assert(operand.data() == nullptr);
+ assert(operand.getUses().size() == 1 && !operand.getDef().valid());
+ assert(uses_map[ind] == 1 && def_map[ind] == 0);
+ tensor_builder->notifyFirstUse(ind);
+ }
+ }
+
+ for (const auto &ind : op_inputs)
+ {
+ if (!tensor_builder->isRegistered(ind))
+ continue;
+ assert(uses_map.find(ind) != uses_map.end());
+ assert(uses_map[ind] > 0);
+ uses_map[ind]--;
+ if (uses_map[ind] == 0)
+ {
+        // plan for deallocation of static tensor node
+ tensor_builder->notifyLastUse(ind);
+ }
+ }
+ }
+
+ _data.graph->operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
+ if (uses_map[ind] == 0)
+ {
+ tensor_builder->notifyLastUse(ind);
+ }
+ });
+
+ // Dispose and validate
+ for (const auto &ind : constants)
+ {
+ --uses_map[ind];
+ if (uses_map[ind] == 0) // To prevent notifyLastUse from being called twice
+ {
+ tensor_builder->notifyLastUse(ind);
+ }
+ }
+
+ assert(
+ std::all_of(uses_map.begin(), uses_map.end(),
+ [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; }));
+
+ assert(
+ std::all_of(def_map.begin(), def_map.end(),
+ [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; }));
+}
+
+ITensorRegistry *BackendContext::genTensors()
+{
+ graph()->operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
+ if (external_operands().contains(ind))
+ return;
+
+ const auto frontend_layout = graph()->layout();
+ const auto backend_layout = operand_layouts().at(ind);
+ ir::OperandInfo backend_info{permuteShape(obj.shape(), frontend_layout, backend_layout),
+ obj.typeInfo(), obj.info().memAllocType(), obj.isConstant()};
+ tensor_builder->registerTensorInfo(ind, backend_info, backend_layout);
+ });
+
+ // TODO Get compiler options from compiler, and use it rather than getting it from Env
+ if (util::getConfigString(util::config::EXECUTOR) == "Linear")
+ {
+ planTensors();
+ }
+ else
+ {
+    // For executors that do not have a fixed linear execution order:
+ // To make tensors never be deallocated, this is a workaround to use static memory planner
+ graph()->operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
+ if (tensor_builder->isRegistered(ind))
+ tensor_builder->notifyFirstUse(ind);
+ });
+ }
+
+ tensor_builder->prepare();
+
+ return tensor_registry.get();
+}
+
+FunctionMap BackendContext::genKernels()
+{
+ FunctionMap ret;
+
+ // kernel_gen
+ for (auto op_ind : _data.op_order)
+ {
+ auto fn_seq = kernel_gen->generate(op_ind);
+ ret.emplace_back(op_ind, std::move(fn_seq));
+ }
+
+ tensor_builder->allocate();
+
+ initConsts();
+
+ // NOTE For memory optimization, we want to free some operand data
+ const_cast<ir::Graph &>(*_data.graph)
+ .operands()
+ .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
+
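+  // Prepare every generated function, then give the tensor builder a chance to release
+  // constant data that is no longer needed after preparation.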
+ for (auto &it : ret)
+ {
+ auto &fn_seq = it.second;
+ fn_seq->iterate([&](exec::IFunction &ifunc) {
+ ifunc.prepare();
+ tensor_builder->postFunctionPrepare();
+ });
+ }
+
+ return ret;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_BACKEND_CONTEXT_H__
+#define __ONERT_BACKEND_GPU_CL_BACKEND_CONTEXT_H__
+
+#include <backend/BackendContext.h>
+#include <util/ConfigSource.h>
+
+#include "ConstantInitializer.h"
+#include "KernelGenerator.h"
+#include "TensorBuilder.h"
+#include "open_cl/InferenceContext.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class BackendContext : public onert::backend::BackendContext
+{
+public:
+ BackendContext(const Backend *backend, ContextData &&data,
+ std::shared_ptr<ITensorRegistry> tensor_registry = nullptr,
+ std::shared_ptr<TensorBuilder> tensor_builder = nullptr,
+ std::shared_ptr<ConstantInitializer> constant_initializer = nullptr,
+ std::shared_ptr<KernelGenerator> kernel_gen = nullptr)
+ : onert::backend::BackendContext(backend, std::move(data), tensor_registry),
+ tensor_builder{tensor_builder}, constant_initializer{constant_initializer}, kernel_gen{
+ kernel_gen}
+ {
+ }
+
+ ITensorRegistry *genTensors() override;
+ FunctionMap genKernels() override;
+
+private:
+ void initConsts();
+ void planTensors();
+
+public:
+ std::shared_ptr<TensorBuilder> tensor_builder;
+ std::shared_ptr<ConstantInitializer> constant_initializer;
+ std::shared_ptr<KernelGenerator> kernel_gen;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_BACKEND_CONTEXT_H__
--- /dev/null
+set(LIB_ONERT_BACKEND_GPU_CL onert_backend_gpu_cl)
+
+nnas_find_package(Opencl_Headers QUIET)
+if(NOT Opencl_Headers_FOUND)
+ return()
+endif(NOT Opencl_Headers_FOUND)
+
+if(NOT BUILD_GPU_CL)
+ return()
+endif(NOT BUILD_GPU_CL)
+
+nnas_find_package(Farmhash QUIET)
+if(NOT Farmhash_FOUND)
+ return()
+endif(NOT Farmhash_FOUND)
+
+nnas_find_package(Abseil QUIET)
+if(NOT Abseil_FOUND)
+ return()
+endif(NOT Abseil_FOUND)
+
+file(GLOB_RECURSE SOURCES "*.cc")
+
+add_library(${LIB_ONERT_BACKEND_GPU_CL} SHARED ${SOURCES})
+
+target_include_directories(${LIB_ONERT_BACKEND_GPU_CL} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+
+target_link_libraries(${LIB_ONERT_BACKEND_GPU_CL} PRIVATE abseil)
+target_link_libraries(${LIB_ONERT_BACKEND_GPU_CL} PRIVATE dl)
+target_link_libraries(${LIB_ONERT_BACKEND_GPU_CL} PRIVATE farmhash)
+target_link_libraries(${LIB_ONERT_BACKEND_GPU_CL} PRIVATE Headers)
+target_link_libraries(${LIB_ONERT_BACKEND_GPU_CL} PRIVATE onert_core)
+target_link_libraries(${LIB_ONERT_BACKEND_GPU_CL} PRIVATE nnfw_common)
+target_link_libraries(${LIB_ONERT_BACKEND_GPU_CL} PRIVATE nnfw_coverage)
+
+set_target_properties(${LIB_ONERT_BACKEND_GPU_CL} PROPERTIES OUTPUT_NAME backend_gpu_cl)
+
+if(CMAKE_BUILD_TYPE_LC STREQUAL "release")
+ add_custom_command(TARGET ${LIB_ONERT_BACKEND_GPU_CL} POST_BUILD
+ COMMAND ${CMAKE_STRIP} "--strip-unneeded" $<TARGET_FILE_NAME:${LIB_ONERT_BACKEND_GPU_CL}>)
+endif()
+
+install(TARGETS ${LIB_ONERT_BACKEND_GPU_CL} DESTINATION lib)
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClConstantInitializer.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+ClConstantInitializer::ClConstantInitializer(const ir::Operands &operands,
+ const std::shared_ptr<ITensorRegistry> &tensor_reg)
+ : _operands{operands}, _tensor_reg{tensor_reg}, _current_layout{ir::Layout::UNKNOWN}
+{
+ // DO NOTHING
+}
+
+void ClConstantInitializer::copyInputInitialize(const ir::Operation &node, uint32_t index)
+{
+ assert(node.getInputs().size() > index);
+
+ const auto &input_index = node.getInputs().at(index);
+ if (input_index.valid())
+ {
+ const auto &input_obj = _operands.at(input_index);
+ registerCopyInitializer(input_index, input_obj);
+ }
+}
+
+void ClConstantInitializer::permuteInputInitialize(const ir::Operation &node, uint32_t index)
+{
+ assert(node.getInputs().size() > index);
+
+ const auto &input_index = node.getInputs().at(index);
+ const auto &input_obj = _operands.at(input_index);
+ registerPermuteInitializer(input_index, input_obj);
+}
+
+// NOTE Workaround for the 16-bit float type. This is enough here since only the size in bytes matters.
+using float16 = uint16_t;
+
+void ClConstantInitializer::registerCopyInitializer(const ir::OperandIndex &index,
+ const ir::Operand &obj)
+{
+ // For only CONSTANTS
+  // TODO Add a check that the tensor has been allocated
+ if (!obj.isConstant())
+ return;
+
+ const auto type = obj.typeInfo().type();
+ using ir::DataType;
+
+ switch (type)
+ {
+ case DataType::FLOAT32:
+ _init_map[index] = copyInit<float>;
+ break;
+ default:
+ throw std::runtime_error("Not supported, yet");
+ break;
+ }
+}
+
+void ClConstantInitializer::registerPermuteInitializer(const ir::OperandIndex &index,
+ const ir::Operand &obj)
+{
+ // For only CONSTANTS
+  // TODO Add a check that the tensor has been allocated
+ if (!obj.isConstant())
+ return;
+
+ const auto type = obj.typeInfo().type();
+ using ir::DataType;
+ using namespace std::placeholders;
+
+ switch (type)
+ {
+ case DataType::FLOAT32:
+ _init_map[index] = std::bind(permuteInit<float>, _1, _2, _current_layout);
+ break;
+ default:
+ throw std::runtime_error("Not supported, yet");
+ break;
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_CLCONSTANT_INITIALIZER_H__
+#define __ONERT_BACKEND_GPU_CL_CLCONSTANT_INITIALIZER_H__
+
+#include "ClTensorRegistry.h"
+
+#include <unordered_map>
+#include <functional>
+
+#include <ir/Coordinates.h>
+#include <ir/Layout.h>
+#include <ir/Operand.h>
+#include <ir/Operands.h>
+#include <ir/OperationVisitor.h>
+#include <backend/ITensorRegistry.h>
+#include <util/logging.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+template <typename T>
+static void Init(const onert::ir::Operand &model_obj, onert::backend::ITensor &obj, const bool copy,
+ const onert::ir::Layout frontend_layout = onert::ir::Layout::UNKNOWN)
+{
+ const auto shape = model_obj.shape();
+ assert(model_obj.data());
+ obj.access([&](::onert::backend::ITensor &tensor) {
+ switch (shape.rank())
+ {
+ case 0:
+ case 1:
+ case 2:
+ case 3:
+ case 4:
+ if (copy)
+ {
+ tensor.enqueueWriteBuffer(model_obj.data()->base(), true);
+ }
+ else
+ {
+ // NYI
+ (void)frontend_layout;
+ throw std::runtime_error{"Not yet supported"};
+ }
+ break;
+ default:
+ throw std::runtime_error{"Not yet supported"};
+ }
+ });
+}
+
+template <typename T>
+void copyInit(const onert::ir::Operand &model_obj, onert::backend::ITensor &obj)
+{
+ Init<T>(model_obj, obj, true);
+}
+
+template <typename T>
+void permuteInit(const onert::ir::Operand &model_obj, onert::backend::ITensor &obj,
+ const onert::ir::Layout frontend_layout)
+{
+ const bool copy = frontend_layout == obj.layout();
+ Init<T>(model_obj, obj, copy, frontend_layout);
+}
+
+class ClConstantInitializer : public ir::OperationVisitor
+{
+public:
+ void run()
+ {
+ assert(_tensor_reg);
+ for (const auto &it : _init_map)
+ {
+ const auto &ind = it.first;
+ const auto &fn = it.second;
+
+ const auto &model_obj = _operands.at(ind);
+ auto tensor_obj = _tensor_reg->getNativeITensor(ind);
+ assert(tensor_obj != nullptr);
+ fn(model_obj, *tensor_obj);
+ VERBOSE(FillOperandData) << "Fill data for operand " << ind << std::endl;
+ }
+ _init_map.clear();
+ }
+
+public:
+ ClConstantInitializer(const ir::Operands &operands,
+ const std::shared_ptr<ITensorRegistry> &tensor_reg);
+
+public:
+ using Initializer = std::function<void(const ir::Operand &, backend::ITensor &)>;
+
+public:
+ void registerDefaultInitializer(const ir::OperandIndex &index, const ir::Operand &obj)
+ {
+ registerPermuteInitializer(index, obj);
+ }
+ void registerCopyInitializer(const ir::OperandIndex &index, const ir::Operand &obj);
+ void registerPermuteInitializer(const ir::OperandIndex &index, const ir::Operand &obj);
+
+public:
+ void setLayout(ir::Layout layout) { _current_layout = layout; }
+ bool exist(const ir::OperandIndex &ind) { return _init_map.find(ind) != _init_map.end(); }
+
+protected:
+ void copyInputInitialize(const ir::Operation &node, uint32_t index);
+ void permuteInputInitialize(const ir::Operation &node, uint32_t index);
+
+protected:
+ const ir::Operands &_operands;
+ std::shared_ptr<ITensorRegistry> _tensor_reg;
+ std::unordered_map<ir::OperandIndex, Initializer> _init_map;
+ ir::Layout _current_layout;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_CLCONSTANT_INITIALIZER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_GPU_CL_OPEN_CL_FUNCTION_H__
+#define __ONERT_GPU_CL_OPEN_CL_FUNCTION_H__
+
+#include <exec/IFunction.h>
+
+#include <vector>
+#include <memory>
+
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/ClCommandQueue.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class ClFunction : public ::onert::exec::IFunction
+{
+public:
+ ClFunction() : _gpu_operations(), _creation_context() {}
+
+public:
+ void configure(std::shared_ptr<CreationContext> creation_context)
+ {
+ _creation_context = creation_context;
+ }
+
+ void add_operation(std::unique_ptr<GPUOperation> gpu_operation)
+ {
+ _gpu_operations.push_back(std::move(gpu_operation));
+ }
+
+ void run() override
+ {
+ for (const auto &gpu_operation : _gpu_operations)
+ {
+ if (!gpu_operation->AddToQueue(_creation_context->queue).ok())
+ {
+ throw std::runtime_error("Failed to AddToQueue.");
+ }
+ }
+ }
+
+ void prepare() override
+ {
+ for (const auto &gpu_operation : _gpu_operations)
+ {
+ if (!gpu_operation->Compile(*_creation_context).ok())
+ {
+ throw std::runtime_error("Failed to Compile.");
+ }
+
+ if (!gpu_operation->UpdateParams().ok())
+ {
+ throw std::runtime_error("Failed to UpdateParams.");
+ }
+ }
+ }
+
+private:
+ std::vector<std::unique_ptr<GPUOperation>> _gpu_operations;
+ std::shared_ptr<CreationContext> _creation_context;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_GPU_CL_OPEN_CL_FUNCTION_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_MEMORY_MANAGER_H__
+#define __ONERT_BACKEND_GPU_CL_MEMORY_MANAGER_H__
+
+#include <cassert>
+
+#include "ir/OperandIndexMap.h"
+#include "ir/Shape.h"
+#include "open_cl/ClContext.h"
+#include "open_cl/InferenceContext.h"
+#include "open_cl/Status.h"
+#include "open_cl/StorageTypeUtil.h"
+#include "open_cl/TensorType.h"
+#include "util/logging.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+template <typename T_ITensor, typename T_Tensor> class ClMemoryManager
+{
+public:
+ ClMemoryManager(CLContext *context) : _context{context} {}
+
+ virtual ~ClMemoryManager() = default;
+
+ virtual void allocate(void)
+ {
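+    // Create an OpenCL tensor for every registered operand, using the shape and descriptor
+    // that were reserved when the tensor was built.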
+ for (const auto &tensor_entry : _tensors)
+ {
+ auto tensor = tensor_entry.second;
+ const auto &t = tensor_reserver_.Get(tensor_entry.first.value());
+ const auto &shape = t->shape;
+ const auto &descriptor = t->descriptor;
+ if (!CreateTensor(*_context, shape, descriptor, tensor->handle()).ok())
+ {
+ return;
+ }
+ }
+ }
+
+ virtual void deallocate(void)
+ {
+ // NYI
+ }
+
+ virtual void startLifetime(const ir::OperandIndex &)
+ { /* DO NOTHING */
+ }
+ virtual void finishLifetime(const ir::OperandIndex &)
+ { /* DO NOTHING */
+ }
+
+ void buildTensor(const ir::OperandIndex &ind, const ir::OperandInfo &info,
+ InferenceContext::CreateInferenceInfo create_info,
+ std::shared_ptr<Environment> environment, DeviceInfo &device_info)
+ {
+ ValueId max_id = 0;
+ auto data_type = DeduceDataTypeFromPrecision(create_info.precision);
+ const auto shape = info.shape();
+
+ auto tensor = std::make_shared<T_Tensor>(shape.rank(), shape, environment);
+ _tensors[ind] = tensor;
+
+ BHWC t_shape;
+ switch (shape.rank())
+ {
+ case 1:
+ // B layout
+ t_shape = BHWC(shape.dim(0), 1, 1, 1);
+ break;
+ case 2:
+ // BC layout
+ t_shape = BHWC(shape.dim(0), 1, 1, shape.dim(1));
+ break;
+ case 3:
+ // BWC layout
+ t_shape = BHWC(shape.dim(0), 1, shape.dim(1), shape.dim(2));
+ break;
+ case 4:
+ // BHWC layout
+ t_shape = BHWC(shape.dim(0), shape.dim(1), shape.dim(2), shape.dim(3));
+ break;
+ default:
+ break;
+ }
+
+ TensorStorageType storage_type = create_info.storage_type;
+ Layout layout = t_shape.b == 1 ? Layout::HWC : Layout::BHWC;
+
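+    // The IR operand index doubles as the GPU tensor ValueId, so both sides share one numbering.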
+ ValueId id = ind.value();
+ storage_type = SelectBestStorageType(device_info, t_shape, storage_type, data_type, layout);
+ auto dummy = std::make_shared<InferenceContext::DummyTensor>();
+ dummy->shape = t_shape;
+ dummy->descriptor = TensorDescriptor{data_type, storage_type, layout};
+ tensor_reserver_.Add(id, dummy);
+
+ max_id = std::max(max_id, id);
+
+ tensor_reserver_.SetNext(max_id + 1);
+ }
+
+ ir::OperandIndexMap<std::shared_ptr<T_Tensor>> &tensors(void) { return _tensors; }
+
+ InferenceContext::TensorReserver &tensorReservers(void) { return tensor_reserver_; }
+
+private:
+ ir::OperandIndexMap<std::shared_ptr<T_Tensor>> _tensors;
+ InferenceContext::TensorReserver tensor_reserver_;
+ CLContext *_context;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_MEMORY_MANAGER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_CL_TENSOR_BUILDER_H__
+#define __ONERT_BACKEND_CL_TENSOR_BUILDER_H__
+
+#include <memory>
+#include <queue>
+
+#include "ClTensorManager.h"
+#include "ClTensorRegistry.h"
+#include "ParentInfo.h"
+
+#include "open_cl/TensorType.h"
+#include "open_cl/TensorTypeUtil.h"
+#include "open_cl/ClDevice.h"
+#include "open_cl/InferenceContext.h"
+
+#include "ir/OperandIndexMap.h"
+#include "ir/OperandIndexSequence.h"
+#include <ir/Operands.h>
+#include <util/Utils.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class UsesType
+{
+ FIRST,
+ LAST
+};
+
+template <typename T_ITensor, typename T_Tensor> class ClTensorBuilder
+{
+public:
+ using T_ClTensorManager = ClTensorManager<T_ITensor, T_Tensor>;
+
+ ClTensorBuilder(const ir::Operands &operands, T_ClTensorManager *tensor_mgr,
+ InferenceContext::CreateInferenceInfo create_info,
+ const std::shared_ptr<Environment> &environment);
+
+ /**
+   * @brief Register tensor information to allocate on the gpu_cl backend
+ * @param[in] ind Operand index
+ * @param[in] info Tensor information
+ * @param[in] layout Tensor data layout
+ */
+ void registerTensorInfo(const ir::OperandIndex &ind, const ir::OperandInfo &info,
+ ir::Layout backend_layout);
+
+ void notifyFirstUse(const ir::OperandIndex &);
+ void notifyLastUse(const ir::OperandIndex &);
+
+ bool isRegistered(const ir::OperandIndex &) const;
+
+ void prepare();
+ void allocate();
+ void postFunctionPrepare();
+
+ T_ClTensorManager *cl_tensor_manager(void) { return _tensor_mgr.get(); }
+
+ void setUsesCount(const ir::OperandIndex &index, size_t num_uses)
+ {
+ assert(_uses_count_map.find(index) != _uses_count_map.end() ? _uses_count_map[index] == num_uses
+ : true);
+ _uses_count_map[index] = num_uses;
+ }
+
+ void parent_map(std::unordered_map<ir::OperandIndex, ParentInfo> &&parent_map)
+ {
+ _parent_map = std::move(parent_map);
+ }
+
+ bool areSubTensorsOf(const ir::OperandIndex &parent, const ir::OperandIndexSequence &seq);
+
+ /**
+ * @brief Check child tensor is allocated as subtensor of parent tensor
+ * @param[in] parent Index of parent
+ * @param[in] child Index of child
+ * @return @c true if child is allocated as subtensor of parent, otherwise @c false
+ */
+ bool isSubTensorOf(const ir::OperandIndex &parent, const ir::OperandIndex &child);
+
+private:
+ void buildTensors(void);
+ ir::OperandIndex findRootParent(ir::OperandIndex index);
+
+private:
+ const ir::Operands &_operands;
+ ir::OperandIndexMap<ir::OperandInfo> _tensor_info_map;
+ ir::OperandIndexMap<ir::Layout> _tensor_layout_map;
+ ir::OperandIndexMap<size_t> _uses_count_map;
+
+ std::unique_ptr<T_ClTensorManager> _tensor_mgr;
+ InferenceContext::CreateInferenceInfo _create_info;
+ std::shared_ptr<Environment> _environment;
+
+ // for linear executor
+ std::vector<std::pair<UsesType, ir::OperandIndex>> _lifetime_seq;
+
+ // Extra info for concat elimination
+ ir::OperandIndexMap<ParentInfo> _parent_map;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#include <cassert>
+#include <stack>
+
+#include "util/logging.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+template <typename T_ITensor, typename T_Tensor>
+ClTensorBuilder<T_ITensor, T_Tensor>::ClTensorBuilder(
+ const ir::Operands &operands, T_ClTensorManager *tensor_mgr,
+ InferenceContext::CreateInferenceInfo create_info,
+ const std::shared_ptr<Environment> &environment)
+ : _operands{operands}, _tensor_mgr{tensor_mgr}, _create_info{create_info}, _environment{
+ environment}
+{
+ assert(_tensor_mgr);
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorBuilder<T_ITensor, T_Tensor>::registerTensorInfo(const ir::OperandIndex &ind,
+ const ir::OperandInfo &info,
+ ir::Layout backend_layout)
+{
+ assert(_tensor_mgr->constTensors().size() == 0);
+ assert(_tensor_mgr->nonconstTensors().size() == 0);
+
+ _uses_count_map[ind] = _operands.at(ind).getUses().size();
+
+ _tensor_info_map.emplace(ind, info);
+ _tensor_layout_map.insert({ind, backend_layout});
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorBuilder<T_ITensor, T_Tensor>::notifyFirstUse(const ir::OperandIndex &ind)
+{
+ _lifetime_seq.emplace_back(UsesType::FIRST, ind);
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorBuilder<T_ITensor, T_Tensor>::notifyLastUse(const ir::OperandIndex &ind)
+{
+ _lifetime_seq.emplace_back(UsesType::LAST, ind);
+}
+
+template <typename T_ITensor, typename T_Tensor>
+bool ClTensorBuilder<T_ITensor, T_Tensor>::isRegistered(const ir::OperandIndex &ind) const
+{
+ return _tensor_info_map.find(ind) != _tensor_info_map.end();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorBuilder<T_ITensor, T_Tensor>::prepare(void)
+{
+ buildTensors();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorBuilder<T_ITensor, T_Tensor>::allocate(void)
+{
+ // Update lifetime sequence to apply subtensor optimization
+
+ std::unordered_map<ir::OperandIndex, ir::OperandIndex> root_map;
+ std::function<ir::OperandIndex &(ir::OperandIndex)> find_root =
+ [&](ir::OperandIndex ind) -> ir::OperandIndex & {
+ ir::OperandIndex &ret = root_map[ind];
+
+ // We know the root parent value already
+ if (ret.valid())
+ return ret;
+
+ auto itr = _parent_map.find(ind);
+ if (itr == _parent_map.end())
+ {
+ // If there is no parent, let's store the value of itself
+ return ret = ind;
+ }
+ else
+ {
+ return ret = find_root(itr->second.parent);
+ }
+ };
+
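+  // Collapse the lifetime sequence onto root parents: keep only the first FIRST and the last
+  // LAST event of each root, ordered by their position in the original sequence.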
+ ir::OperandIndexMap<bool> first_use_check;
+ ir::OperandIndexMap<bool> last_use_check;
+ std::map<size_t, std::pair<UsesType, ir::OperandIndex>> lifetime_map;
+ for (size_t i = 0; i < _lifetime_seq.size(); i++)
+ {
+ auto &entry = _lifetime_seq[i];
+ if (entry.first != UsesType::FIRST)
+ continue;
+ auto root_ind = find_root(entry.second);
+ if (first_use_check[root_ind])
+ continue;
+ first_use_check[root_ind] = true;
+ lifetime_map[i] = {UsesType::FIRST, root_ind};
+ }
+
+ for (int i = _lifetime_seq.size() - 1; i >= 0; i--)
+ {
+ auto &entry = _lifetime_seq[i];
+ if (entry.first != UsesType::LAST)
+ continue;
+ auto root_ind = find_root(entry.second);
+ if (last_use_check[root_ind])
+ continue;
+ last_use_check[root_ind] = true;
+ lifetime_map[i] = {UsesType::LAST, root_ind};
+ }
+
+ for (auto &entry : lifetime_map)
+ {
+ auto &use = entry.second;
+ auto use_type = use.first;
+ auto use_index = use.second;
+ assert(use_index.valid());
+ if (use_type == UsesType::FIRST)
+ _tensor_mgr->startLifetime(use_index);
+ else
+ _tensor_mgr->finishLifetime(use_index);
+ }
+
+ _tensor_mgr->allocateConsts();
+
+  // TODO `_parent_map` is currently filled for all Concat nodes, even ones this backend does
+  // not handle, so the assertion below does not hold yet. After refactoring BackendContext we
+  // can uncomment this:
+ // assert(_tensor_info_map.size() ==
+ // _tensor_mgr->nonconstTensors().size() + num of constants of _tensor_info_map +
+ // _parent_map.size());
+ _tensor_mgr->allocateNonconsts();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorBuilder<T_ITensor, T_Tensor>::postFunctionPrepare(void)
+{
+ _tensor_mgr->tryDeallocConstants();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorBuilder<T_ITensor, T_Tensor>::buildTensors(void)
+{
+ assert(_tensor_mgr->constTensors().size() == 0);
+ assert(_tensor_mgr->nonconstTensors().size() == 0);
+ // Normal tensors
+ for (auto &entry : _tensor_info_map)
+ {
+ auto ind = entry.first;
+ if (_parent_map.count(ind) > 0)
+ continue;
+
+ const auto &info = entry.second;
+ _tensor_mgr->buildTensor(ind, info, _create_info, _environment, _environment->device().info_);
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_CL_TENSOR_BUILDER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_TENSOR_MANAGER_H__
+#define __ONERT_BACKEND_GPU_CL_TENSOR_MANAGER_H__
+
+#include "ClMemoryManager.h"
+
+#include "open_cl/InferenceContext.h"
+#include "open_cl/TensorType.h"
+
+#include "ir/OperandInfo.h"
+#include "ir/OperandIndexMap.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+template <typename T_ITensor, typename T_Tensor> class ClTensorManager
+{
+public:
+ using T_ClMemoryManager = ClMemoryManager<T_ITensor, T_Tensor>;
+
+ ClTensorManager(T_ClMemoryManager *const_mgr, T_ClMemoryManager *nonconst_mgr);
+
+ virtual ~ClTensorManager() = default;
+
+ void allocateConsts(void);
+ void allocateNonconsts(void);
+ void deallocateConsts(void);
+ void deallocateNonconsts(void);
+
+ void buildTensor(const ir::OperandIndex &ind, const ir::OperandInfo &info,
+ InferenceContext::CreateInferenceInfo create_info,
+ std::shared_ptr<Environment> environment, DeviceInfo &device_info);
+
+ std::shared_ptr<T_ITensor> findTensorAsParent(const ir::OperandIndex &ind);
+
+ void startLifetime(const ir::OperandIndex &ind);
+ void finishLifetime(const ir::OperandIndex &ind);
+
+ std::shared_ptr<T_ITensor> at(const ir::OperandIndex &ind);
+ std::shared_ptr<InferenceContext::DummyTensor> atR(const ir::OperandIndex &ind);
+
+ InferenceContext::TensorReserver &constTensorReservers(void);
+ InferenceContext::TensorReserver &nonconstTensorReservers(void);
+
+ ir::OperandIndexMap<std::shared_ptr<T_Tensor>> &constTensors(void);
+ ir::OperandIndexMap<std::shared_ptr<T_Tensor>> &nonconstTensors(void);
+
+ void iterate(const std::function<void(const ir::OperandIndex &)> &fn);
+
+ void tryDeallocConstants(void);
+
+private:
+ std::unique_ptr<T_ClMemoryManager> _const_mgr;
+ std::unique_ptr<T_ClMemoryManager> _nonconst_mgr;
+ ir::OperandIndexMap<T_ClMemoryManager &> _ind_to_mgr;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#include <cassert>
+#include "util/logging.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+template <typename T_ITensor, typename T_Tensor>
+ClTensorManager<T_ITensor, T_Tensor>::ClTensorManager(T_ClMemoryManager *const_mgr,
+ T_ClMemoryManager *nonconst_mgr)
+ : _const_mgr{const_mgr}, _nonconst_mgr{nonconst_mgr}
+{
+ // DO NOTHING
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorManager<T_ITensor, T_Tensor>::allocateConsts(void)
+{
+ _const_mgr->allocate();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorManager<T_ITensor, T_Tensor>::allocateNonconsts(void)
+{
+ _nonconst_mgr->allocate();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorManager<T_ITensor, T_Tensor>::deallocateConsts(void)
+{
+ _const_mgr->deallocate();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorManager<T_ITensor, T_Tensor>::deallocateNonconsts(void)
+{
+ _nonconst_mgr->deallocate();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorManager<T_ITensor, T_Tensor>::buildTensor(
+ const ir::OperandIndex &ind, const ir::OperandInfo &info,
+ InferenceContext::CreateInferenceInfo create_info, std::shared_ptr<Environment> environment,
+ DeviceInfo &device_info)
+{
+ assert(_ind_to_mgr.find(ind) == _ind_to_mgr.end());
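+  // Constant and non-constant tensors are owned by separate memory managers; remember which
+  // manager owns this operand so later lookups can be routed to it.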
+
+ if (info.isConstant())
+ {
+ _const_mgr->buildTensor(ind, info, create_info, environment, device_info);
+ _ind_to_mgr.insert({ind, *_const_mgr});
+ }
+ else
+ {
+ _nonconst_mgr->buildTensor(ind, info, create_info, environment, device_info);
+ _ind_to_mgr.insert({ind, *_nonconst_mgr});
+ }
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorManager<T_ITensor, T_Tensor>::startLifetime(const ir::OperandIndex &ind)
+{
+ assert(_ind_to_mgr.find(ind) != _ind_to_mgr.end());
+ _ind_to_mgr.at(ind).startLifetime(ind);
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorManager<T_ITensor, T_Tensor>::finishLifetime(const ir::OperandIndex &ind)
+{
+ assert(_ind_to_mgr.find(ind) != _ind_to_mgr.end());
+ _ind_to_mgr.at(ind).finishLifetime(ind);
+}
+
+template <typename T_ITensor, typename T_Tensor>
+std::shared_ptr<T_ITensor> ClTensorManager<T_ITensor, T_Tensor>::at(const ir::OperandIndex &ind)
+{
+ if (_ind_to_mgr.find(ind) == _ind_to_mgr.end())
+ return nullptr;
+
+ auto &tensors = _ind_to_mgr.at(ind).tensors();
+ if (tensors.find(ind) != tensors.end())
+ {
+ return tensors.at(ind);
+ }
+
+ return nullptr;
+}
+
+template <typename T_ITensor, typename T_Tensor>
+ir::OperandIndexMap<std::shared_ptr<T_Tensor>> &
+ClTensorManager<T_ITensor, T_Tensor>::constTensors(void)
+{
+ return _const_mgr->tensors();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+ir::OperandIndexMap<std::shared_ptr<T_Tensor>> &
+ClTensorManager<T_ITensor, T_Tensor>::nonconstTensors(void)
+{
+ return _nonconst_mgr->tensors();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+std::shared_ptr<InferenceContext::DummyTensor>
+ClTensorManager<T_ITensor, T_Tensor>::atR(const ir::OperandIndex &ind)
+{
+ if (_nonconst_mgr->tensorReservers().HaveTensor(ind.value()))
+ {
+ return _nonconst_mgr->tensorReservers().Get(ind.value());
+ }
+ else if (_const_mgr->tensorReservers().HaveTensor(ind.value()))
+ {
+ return _const_mgr->tensorReservers().Get(ind.value());
+ }
+ return nullptr;
+}
+
+template <typename T_ITensor, typename T_Tensor>
+InferenceContext::TensorReserver &ClTensorManager<T_ITensor, T_Tensor>::constTensorReservers(void)
+{
+ return _const_mgr->tensorReservers();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+InferenceContext::TensorReserver &
+ClTensorManager<T_ITensor, T_Tensor>::nonconstTensorReservers(void)
+{
+ return _nonconst_mgr->tensorReservers();
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorManager<T_ITensor, T_Tensor>::iterate(
+ const std::function<void(const ir::OperandIndex &)> &fn)
+{
+ for (auto it : _nonconst_mgr->tensors())
+ fn(it.first);
+
+ for (auto it : _const_mgr->tensors())
+ fn(it.first);
+}
+
+template <typename T_ITensor, typename T_Tensor>
+void ClTensorManager<T_ITensor, T_Tensor>::tryDeallocConstants(void)
+{
+ // NYI
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_TENSOR_MANAGER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_TENSOR_REGISTRY_H__
+#define __ONERT_BACKEND_GPU_CL_TENSOR_REGISTRY_H__
+
+#include "backend/ITensorRegistry.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+/**
+ * @brief Tensor registry class for the gpu_cl backend
+ *
+ * This is implemented as a wrapper of ClTensorManager.
+ */
+template <typename T_ClTensorManager> class ClTensorRegistry : public ITensorRegistry
+{
+public:
+ ClTensorRegistry(T_ClTensorManager *tensor_mgr) : _tensor_mgr{tensor_mgr} {}
+
+ ITensor *getITensor(const ir::OperandIndex &ind) override { return _tensor_mgr->at(ind).get(); }
+
+ ITensor *getNativeITensor(const ir::OperandIndex &ind) override { return getITensor(ind); }
+
+ auto getClTensor(const ir::OperandIndex &ind) { return _tensor_mgr->at(ind).get(); }
+
+ auto getClTensorReserver(const ir::OperandIndex &ind) { return _tensor_mgr->atR(ind); }
+
+private:
+ T_ClTensorManager *_tensor_mgr;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_TENSOR_REGISTRY_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Config.h"
+
+#include <dlfcn.h>
+#include "open_cl/OpenclWrapper.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+Config::~Config() { UnloadOpenCL(_handle); }
+
+bool Config::initialize()
+{
+ if (LoadOpenCL(&_handle).ok())
+ {
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+}
+
+ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout) { return ir::Layout::NHWC; }
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_CONFIG_H__
+#define __ONERT_BACKEND_GPU_CL_CONFIG_H__
+
+#include <backend/IConfig.h>
+#include <memory>
+#include <util/ITimer.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class Config : public IConfig
+{
+public:
+ virtual ~Config();
+
+public:
+ std::string id() override { return "gpu_cl"; }
+ bool initialize() override;
+ ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override;
+ bool supportPermutation() override { return true; }
+ bool supportDynamicTensor() override { return false; }
+ bool supportFP16() override { return true; }
+ std::unique_ptr<util::ITimer> timer() override { return std::make_unique<util::CPUTimer>(); }
+
+private:
+ void *_handle = nullptr;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_CONFIG_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConstantInitializer.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+ConstantInitializer::ConstantInitializer(const ir::Operands &operands,
+ const std::shared_ptr<ITensorRegistry> &tensor_reg)
+ : ClConstantInitializer{operands, tensor_reg}
+{
+ // DO NOTHING
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_CONSTANT_INITIALIZER_H__
+#define __ONERT_BACKEND_GPU_CL_CONSTANT_INITIALIZER_H__
+
+#include "ClConstantInitializer.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class ConstantInitializer : public ClConstantInitializer
+{
+public:
+ ConstantInitializer(const ir::Operands &operands,
+ const std::shared_ptr<ITensorRegistry> &tensor_reg);
+
+public:
+ using ClConstantInitializer::visit;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_CONSTANT_INITIALIZER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <stdexcept>
+
+#include <backend/basic/KernelGeneratorBase.h>
+
+#include "KernelGenerator.h"
+
+#include "ClTensorRegistry.h"
+#include "ClFunction.h"
+#include "TensorManager.h"
+
+#include "open_cl/selectors/ConvolutionSelector.h"
+#include "open_cl/selectors/DwConvolutionSelector.h"
+#include "open_cl/selectors/SimpleSelectors.h"
+
+#include "ir/Operations.h"
+#include "ir/Operations.Include.h"
+#include "ir/Index.h"
+#include "ir/DataType.h"
+#include "ir/InternalType.h"
+#include "exec/NopFunction.h"
+#include "exec/FunctionSequence.h"
+#include "util/logging.h"
+#include "util/Utils.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); }
+
+template <typename AttrT>
+void UpdatePadding(const ir::PaddingType type, const BHWC &input_shape, AttrT *attr)
+{
+ if (type == ir::PaddingType::SAME)
+ {
+ attr->padding = CalculateSamePadding(input_shape, *attr);
+ }
+ else
+ {
+ attr->padding.prepended = HW(0, 0);
+ attr->padding.appended = HW(0, 0);
+ }
+}
+
+gpu_cl::PoolingType convertPoolType(ir::operation::Pool2D::PoolType type_ir)
+{
+ switch (type_ir)
+ {
+ case ir::operation::Pool2D::PoolType::AVG:
+ return gpu_cl::PoolingType::AVERAGE;
+ case ir::operation::Pool2D::PoolType::MAX:
+ return gpu_cl::PoolingType::MAX;
+ default:
+ throw std::runtime_error("gpu_Cl KernelGenerator : Not supported operation yet");
+ }
+}
+
+KernelGenerator::KernelGenerator(const ir::Graph &graph,
+ const std::shared_ptr<TensorBuilder> &tensor_builder,
+ const std::shared_ptr<ClTensorRegistry<TensorManager>> &tensor_reg,
+ const std::shared_ptr<CreationContext> &creation_context)
+ : basic::KernelGeneratorBase{graph}, _ctx(graph.operands()),
+ _operations_ctx(graph.operations()), _current_layout{graph.layout()},
+ _tensor_builder(tensor_builder), _tensor_reg(tensor_reg), _creation_context(creation_context)
+{
+}
+
+std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationIndex ind)
+{
+ auto ret = std::make_unique<exec::FunctionSequence>();
+ ret->enableDynamicShapeInferer(false);
+
+ const auto &op = _graph.operations().at(ind);
+ op.accept(*this);
+ ret->append(releaseFunction());
+ return ret;
+}
+
+void KernelGenerator::visit(const ir::operation::BinaryArithmetic &node)
+{
+ const auto ofm_index{node.getOutputs().at(0)};
+ const auto lhs_index{node.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS)};
+ const auto rhs_index{node.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS)};
+
+ // const auto activation = node.param().activation;
+
+ OperationDef op_def;
+ op_def.precision = CalculationsPrecision::F32;
+
+ op_def.src_tensors.push_back(_tensor_reg->getClTensorReserver(lhs_index)->descriptor);
+ auto lhs_shape = _tensor_reg->getClTensorReserver(lhs_index)->shape;
+
+ op_def.src_tensors.push_back(_tensor_reg->getClTensorReserver(rhs_index)->descriptor);
+ auto rhs_shape = _tensor_reg->getClTensorReserver(rhs_index)->shape;
+
+ op_def.dst_tensors.push_back(_tensor_reg->getClTensorReserver(ofm_index)->descriptor);
+ auto out_shape = _tensor_reg->getClTensorReserver(ofm_index)->shape;
+
+ auto fn = std::make_unique<ClFunction>();
+
+ std::unique_ptr<GPUOperation> gpu_op;
+ switch (node.param().arithmetic_type)
+ {
+ case ir::operation::BinaryArithmetic::ArithmeticType::ADD:
+ {
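+      // SelectAdd needs the channel count of each input together with the output channel count.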
+ std::vector<int> channels(2);
+ channels[0] = lhs_shape.c;
+ channels[1] = rhs_shape.c;
+ SelectAdd(op_def, channels, out_shape.c, &gpu_op);
+
+ auto ofm_tensor = _tensor_reg->getClTensor(ofm_index);
+ auto lhs_tensor = _tensor_reg->getClTensor(lhs_index);
+ auto rhs_tensor = _tensor_reg->getClTensor(rhs_index);
+ gpu_op->SetSrc(lhs_tensor->handle(), ir::operation::BinaryArithmetic::Input::LHS);
+ gpu_op->SetSrc(rhs_tensor->handle(), ir::operation::BinaryArithmetic::Input::RHS);
+ gpu_op->SetDst(ofm_tensor->handle(), 0);
+
+ fn->configure(_creation_context);
+ fn->add_operation(std::move(gpu_op));
+ break;
+ }
+ case ir::operation::BinaryArithmetic::ArithmeticType::SUB:
+ {
+ // NYI
+ break;
+ }
+ case ir::operation::BinaryArithmetic::ArithmeticType::MUL:
+ {
+ // NYI
+ break;
+ }
+ case ir::operation::BinaryArithmetic::ArithmeticType::DIV:
+ {
+ // NYI
+ break;
+ }
+ default:
+ assert(false && "The BinaryArithmetic operation supports only binary arithmetic operations");
+ break;
+ }
+
+ _return_fn = std::move(fn);
+}
+
+void KernelGenerator::visit(const ir::operation::Conv2D &node)
+{
+ auto output{node.getOutputs().at(0)};
+
+ auto input{node.getInputs().at(ir::operation::Conv2D::INPUT)};
+ auto kernel{node.getInputs().at(ir::operation::Conv2D::KERNEL)};
+ auto bias{node.getInputs().at(ir::operation::Conv2D::BIAS)};
+
+ const auto param = node.param();
+
+ OperationDef op_def;
+ op_def.precision = CalculationsPrecision::F32;
+
+ op_def.src_tensors.push_back(_tensor_reg->getClTensorReserver(input)->descriptor);
+
+ auto input_shape = _tensor_reg->getClTensorReserver(input)->shape;
+ auto kernel_shape = _tensor_reg->getClTensorReserver(kernel)->shape;
+ auto output_shape = _tensor_reg->getClTensorReserver(output)->shape;
+ auto bias_shape = _tensor_reg->getClTensorReserver(bias)->shape;
+
+ op_def.dst_tensors.push_back(_tensor_reg->getClTensorReserver(output)->descriptor);
+
+ ModelHints hints;
+ std::unique_ptr<GPUOperation> gpu_op; // = InitSingleOpSubgraph(inputs, outputs, gpu_subgraph);
+
+ auto input_tensor = _tensor_reg->getClTensor(input);
+ auto kernel_tensor = _tensor_reg->getClTensor(kernel);
+ auto bias_tensor = _tensor_reg->getClTensor(bias);
+ auto output_tensor = _tensor_reg->getClTensor(output);
+
+ gpu_cl::Convolution2DAttributes attr;
+ attr.strides = ToHW(param.stride.vertical, param.stride.horizontal);
+ attr.dilations = HW(std::max(static_cast<u_int32_t>(1), param.dilation.height_factor),
+ std::max(static_cast<u_int32_t>(1), param.dilation.width_factor));
+
+  bool is_weight = _ctx.at(kernel).isConstant();
+
+ if (is_weight)
+ {
+ attr.weights.id = kernel.value();
+ attr.weights.shape.o = kernel_shape.b;
+ attr.weights.shape.h = kernel_shape.h;
+ attr.weights.shape.w = kernel_shape.w;
+ attr.weights.shape.i = kernel_shape.c;
+ attr.weights.data.resize(kernel_shape.DimensionsProduct());
+ memcpy(attr.weights.data.data(), _ctx.at(kernel).data()->base(), kernel_tensor->total_size());
+ }
+
+ attr.bias.id = bias.value();
+ // TODO Modify
+ attr.bias.shape.v = bias_shape.b != 1 ? bias_shape.b : bias_shape.c;
+ attr.bias.data.resize(bias_shape.DimensionsProduct());
+ memcpy(attr.bias.data.data(), _ctx.at(bias).data()->base(), bias_tensor->total_size());
+
+ UpdatePadding(param.padding.type, input_shape, &attr);
+
+ gpu_op = SelectConvolution(attr, output_shape, _creation_context->GetDeviceInfo(), op_def, hints);
+ gpu_op->SetSrc(input_tensor->handle(), ir::operation::Conv2D::INPUT);
+
+ auto fn = std::make_unique<ClFunction>();
+
+ fn->configure(_creation_context);
+
+ const auto activation = node.param().activation;
+
+ switch (activation)
+ {
+ case ir::Activation::NONE:
+ {
+ gpu_op->SetDst(output_tensor->handle(), 0);
+ fn->add_operation(std::move(gpu_op));
+ break;
+ }
+ case ir::Activation::RELU6:
+ {
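+      // RELU6 is not fused into the convolution kernel here: the convolution writes to an
+      // intermediate tensor and a separate ReLU (clip = 6) operation produces the final output.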
+ std::unique_ptr<GPUOperation> gpu_op_1;
+ OperationDef op_def_1;
+ std::shared_ptr<Tensor> new_tensor = std::make_shared<Tensor>();
+
+ _new_tensors[output] = new_tensor;
+ if (!CreateTensor(*_creation_context->context, output_shape,
+ _tensor_reg->getClTensorReserver(output)->descriptor, new_tensor.get())
+ .ok())
+ {
+ throw std::runtime_error("Error CreateTensor.");
+ }
+
+ gpu_op->SetDst(new_tensor.get(), 0);
+ fn->add_operation(std::move(gpu_op));
+ op_def_1.precision = CalculationsPrecision::F32;
+ op_def_1.src_tensors.push_back(_tensor_reg->getClTensorReserver(output)->descriptor);
+ op_def_1.dst_tensors.push_back(_tensor_reg->getClTensorReserver(output)->descriptor);
+
+ // - ReLU6: clip = 6, alpha = 0
+ ReLUAttributes attr_1;
+ attr_1.clip = 6;
+ attr_1.alpha = 0;
+ gpu_op_1 = SelectReLU(attr_1, op_def_1);
+
+ gpu_op_1->SetSrc(new_tensor.get(), 0);
+ gpu_op_1->SetDst(output_tensor->handle(), 0);
+ fn->add_operation(std::move(gpu_op_1));
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("gpu_cl KernelGenerator : Not supported operation yet");
+ }
+ }
+
+ _return_fn = std::move(fn);
+}
+
+void KernelGenerator::visit(const ir::operation::DepthwiseConv2D &node)
+{
+ using ir::operation::DepthwiseConv2D;
+
+ const auto ofm_index{node.getOutputs().at(0)};
+ const auto ifm_index{node.getInputs().at(DepthwiseConv2D::Input::INPUT)};
+ const auto ker_index{node.getInputs().at(DepthwiseConv2D::Input::KERNEL)};
+ const auto bias_index{node.getInputs().at(DepthwiseConv2D::Input::BIAS)};
+
+ const auto stride = node.param().stride;
+ const auto dilation = node.param().dilation;
+ const auto padding = node.param().padding;
+
+ const auto multiplier = node.param().multiplier;
+
+ auto ofm_tensor = _tensor_reg->getClTensor(ofm_index);
+ auto ifm_tensor = _tensor_reg->getClTensor(ifm_index);
+ auto ker_tensor = _tensor_reg->getClTensor(ker_index);
+ auto bias_tensor = _tensor_reg->getClTensor(bias_index);
+
+ bool is_weight = _ctx.at(ker_index).isConstant();
+ OperationDef op_def;
+ op_def.precision = CalculationsPrecision::F32;
+
+ op_def.src_tensors.push_back(_tensor_reg->getClTensorReserver(ifm_index)->descriptor);
+ auto input_shape = _tensor_reg->getClTensorReserver(ifm_index)->shape;
+
+ auto ker_shape = _tensor_reg->getClTensorReserver(ker_index)->shape;
+
+ op_def.dst_tensors.push_back(_tensor_reg->getClTensorReserver(ofm_index)->descriptor);
+ auto out_shape = _tensor_reg->getClTensorReserver(ofm_index)->shape;
+ auto bias_shape = _tensor_reg->getClTensorReserver(bias_index)->shape;
+
+ DepthwiseConvolution2DAttributes attr;
+ attr.strides = ToHW(stride.vertical, stride.horizontal);
+ attr.dilations = HW(std::max(static_cast<uint32_t>(1), dilation.height_factor),
+ std::max(static_cast<uint32_t>(1), dilation.width_factor));
+
+ if (is_weight)
+ {
+ attr.weights.id = ker_index.value();
+ attr.weights.shape.o = ker_shape.b;
+ attr.weights.shape.h = ker_shape.h;
+ attr.weights.shape.w = ker_shape.w;
+ attr.weights.shape.i = ker_shape.c;
+ attr.weights.data.resize(ker_shape.DimensionsProduct());
+ memcpy(attr.weights.data.data(), _ctx.at(ker_index).data()->base(), ker_tensor->total_size());
+ }
+ attr.bias.id = bias_index.value();
+ attr.bias.shape.v = bias_shape.b != 1 ? bias_shape.b : bias_shape.c;
+ attr.bias.data.resize(bias_shape.DimensionsProduct());
+ memcpy(attr.bias.data.data(), _ctx.at(bias_index).data()->base(), bias_tensor->total_size());
+ UpdatePadding(padding.type, input_shape, &attr);
+
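+ // Depthwise weights are stored with the output channel innermost; when the channel
+ // multiplier != 1, repack them into OHWI order so the selected kernel reads one
+ // contiguous filter per output channel.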
+ if (multiplier != 1)
+ {
+ const int input_depth = input_shape.c;
+ const int filter_height = ker_shape.h;
+ const int filter_width = ker_shape.w;
+ const int output_depth = out_shape.c;
+
+ InternalTensor<OHWI, DataType::FLOAT32> weights;
+ weights.id = attr.weights.id;
+ weights.shape = OHWI(output_depth, filter_height, filter_width, input_depth);
+ weights.data.resize(weights.shape.DimensionsProduct());
+ float *dst = &weights.data[0];
+ for (int j = 0; j < output_depth; ++j)
+ {
+ const float *src = attr.weights.data.data() + j;
+ for (int i = 0; i < filter_height * filter_width; ++i)
+ {
+ *dst = *src;
+ dst++;
+ src += output_depth;
+ }
+ }
+ attr.weights = std::move(weights);
+ }
+
+ auto fn = std::make_unique<ClFunction>();
+ std::unique_ptr<GPUOperation> gpu_op;
+
+ if (is_weight)
+ {
+ gpu_op = SelectDWConvolution(attr, _creation_context->GetDeviceInfo(), op_def);
+ }
+ else
+ {
+ if (ker_shape.b != 1)
+ {
+ throw std::runtime_error(
+ "Depthwise runtime weights with channel multiplier != 1 are not supported");
+ }
+ gpu_op = SelectDWConvolutionDynamicWeights(attr, _creation_context->GetDeviceInfo(), op_def);
+ }
+
+ gpu_op->SetSrc(ifm_tensor->handle(), ir::operation::DepthwiseConv2D::Input::INPUT);
+
+ fn->configure(_creation_context);
+
+ const auto activation = node.param().activation;
+
+ switch (activation)
+ {
+ case ir::Activation::NONE:
+ {
+ gpu_op->SetDst(ofm_tensor->handle(), 0);
+ fn->add_operation(std::move(gpu_op));
+ break;
+ }
+ case ir::Activation::RELU6:
+ {
+ std::unique_ptr<GPUOperation> gpu_op_1;
+ OperationDef op_def_1;
+ std::shared_ptr<Tensor> new_tensor = std::make_shared<Tensor>();
+
+ _new_tensors[ofm_index] = new_tensor;
+ if (!CreateTensor(*_creation_context->context, out_shape,
+ _tensor_reg->getClTensorReserver(ofm_index)->descriptor, new_tensor.get())
+ .ok())
+ {
+ throw std::runtime_error("Error CreateTensor.");
+ }
+
+ gpu_op->SetDst(new_tensor.get(), 0);
+ fn->add_operation(std::move(gpu_op));
+ op_def_1.precision = CalculationsPrecision::F32;
+ op_def_1.src_tensors.push_back(_tensor_reg->getClTensorReserver(ofm_index)->descriptor);
+ op_def_1.dst_tensors.push_back(_tensor_reg->getClTensorReserver(ofm_index)->descriptor);
+
+ // - ReLU6: clip = 6, alpha = 0
+ ReLUAttributes attr_1;
+ attr_1.clip = 6;
+ attr_1.alpha = 0;
+ gpu_op_1 = SelectReLU(attr_1, op_def_1);
+
+ gpu_op_1->SetSrc(new_tensor.get(), 0);
+ gpu_op_1->SetDst(ofm_tensor->handle(), 0);
+ fn->add_operation(std::move(gpu_op_1));
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("gpu_cl KernelGenerator : Not supported operation yet");
+ }
+ }
+
+ _return_fn = std::move(fn);
+}
+
+void KernelGenerator::visit(const ir::operation::ElementwiseActivation &node)
+{
+ std::unique_ptr<GPUOperation> gpu_op;
+ auto fn = std::make_unique<ClFunction>();
+
+ switch (node.param().op_type)
+ {
+ case ir::operation::ElementwiseActivation::Type::LEAKY_RELU:
+ case ir::operation::ElementwiseActivation::Type::RELU:
+ {
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{
+ node.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT)};
+
+ OperationDef op_def;
+ op_def.precision = CalculationsPrecision::F32;
+ auto output_tensor = _tensor_reg->getClTensor(output_index);
+ auto input_tensor = _tensor_reg->getClTensor(input_index);
+ op_def.dst_tensors.push_back(_tensor_reg->getClTensorReserver(output_index)->descriptor);
+ op_def.src_tensors.push_back(_tensor_reg->getClTensorReserver(input_index)->descriptor);
+
+ ReLUAttributes attr;
+ if (ir::operation::ElementwiseActivation::Type::LEAKY_RELU == node.param().op_type)
+ {
+ attr.alpha = node.param().alpha;
+ attr.clip = 0;
+ }
+ else
+ {
+ attr.alpha = node.param().beta;
+ attr.clip = node.param().alpha;
+ }
+ gpu_op = SelectReLU(attr, op_def);
+ gpu_op->SetSrc(input_tensor->handle(), ir::operation::ElementwiseActivation::Input::INPUT);
+ gpu_op->SetDst(output_tensor->handle(), 0);
+ fn->configure(_creation_context);
+ fn->add_operation(std::move(gpu_op));
+
+ _return_fn = std::move(fn);
+ break;
+ }
+ default:
+ throw std::runtime_error("gpu_cl KernelGenerator : Not supported operation yet");
+ }
+}
+
+void KernelGenerator::visit(const ir::operation::Pool2D &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
+
+ OperationDef op_def;
+ op_def.precision = CalculationsPrecision::F32;
+
+ op_def.src_tensors.push_back(_tensor_reg->getClTensorReserver(input_index)->descriptor);
+ auto input_shape = _tensor_reg->getClTensorReserver(input_index)->shape;
+
+ op_def.dst_tensors.push_back(_tensor_reg->getClTensorReserver(output_index)->descriptor);
+
+ const auto kh = node.param().kh;
+ const auto kw = node.param().kw;
+ const auto stride = node.param().stride;
+ const auto op_type = convertPoolType(node.param().op_type);
+
+ Pooling2DAttributes attributes;
+ attributes.type = op_type;
+ attributes.kernel = HW(kh > 0 ? kh : 1, kw > 0 ? kw : 1);
+ attributes.strides =
+ HW(stride.vertical > 0 ? stride.vertical : 1, stride.horizontal > 0 ? stride.horizontal : 1);
+
+ if (node.param().padding.type == ir::PaddingType::SAME)
+ {
+ attributes.padding = CalculateSamePadding(input_shape, attributes);
+ }
+ else
+ {
+ attributes.padding.prepended = HW(0, 0);
+ attributes.padding.appended = HW(0, 0);
+ }
+
+ auto fn = std::make_unique<ClFunction>();
+ std::unique_ptr<GPUOperation> gpu_op;
+ gpu_op = SelectPooling(attributes, op_def);
+
+ auto input_tensor = _tensor_reg->getClTensor(input_index);
+ auto output_tensor = _tensor_reg->getClTensor(output_index);
+
+ gpu_op->SetSrc(input_tensor->handle(), ir::operation::Pool2D::Input::INPUT);
+ gpu_op->SetDst(output_tensor->handle(), 0);
+
+ fn->configure(_creation_context);
+ fn->add_operation(std::move(gpu_op));
+
+ _return_fn = std::move(fn);
+}
+
+void KernelGenerator::visit(const ir::operation::Reshape &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(ir::operation::Reshape::Input::INPUT)};
+
+ OperationDef op_def;
+ op_def.precision = CalculationsPrecision::F32;
+
+ op_def.src_tensors.push_back(_tensor_reg->getClTensorReserver(input_index)->descriptor);
+ auto input_shape = _tensor_reg->getClTensorReserver(input_index)->shape;
+
+ op_def.dst_tensors.push_back(_tensor_reg->getClTensorReserver(output_index)->descriptor);
+ auto output_shape = _tensor_reg->getClTensorReserver(output_index)->shape;
+
+ ReshapeAttributes attr;
+ attr.new_shape = output_shape;
+
+ auto fn = std::make_unique<ClFunction>();
+ std::unique_ptr<GPUOperation> gpu_op;
+ const int src_channels = input_shape.c;
+ SelectReshape(src_channels, attr.new_shape.c, op_def, &gpu_op);
+
+ auto input_tensor = _tensor_reg->getClTensor(input_index);
+ auto output_tensor = _tensor_reg->getClTensor(output_index);
+ gpu_op->SetSrc(input_tensor->handle(), ir::operation::Reshape::Input::INPUT);
+ gpu_op->SetDst(output_tensor->handle(), 0);
+
+ fn->configure(_creation_context);
+ fn->add_operation(std::move(gpu_op));
+
+ _return_fn = std::move(fn);
+}
+
+void KernelGenerator::visit(const ir::operation::Softmax &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(ir::operation::Softmax::Input::INPUT)};
+
+ const auto beta = node.param().beta;
+
+ if (beta != 1.0)
+ {
+ throw std::runtime_error("Softmax.beta != 1 is not supported in gpu_cl");
+ }
+
+ OperationDef op_def;
+ op_def.precision = CalculationsPrecision::F32;
+
+ op_def.dst_tensors.push_back(_tensor_reg->getClTensorReserver(output_index)->descriptor);
+
+ op_def.src_tensors.push_back(_tensor_reg->getClTensorReserver(input_index)->descriptor);
+ auto input_shape = _tensor_reg->getClTensorReserver(input_index)->shape;
+
+ auto fn = std::make_unique<ClFunction>();
+
+ std::unique_ptr<GPUOperation> gpu_op;
+ SelectSoftmax(input_shape, op_def, &gpu_op);
+ auto output_tensor = _tensor_reg->getClTensor(output_index);
+ auto input_tensor = _tensor_reg->getClTensor(input_index);
+
+ gpu_op->SetSrc(input_tensor->handle(), ir::operation::Softmax::Input::INPUT);
+ gpu_op->SetDst(output_tensor->handle(), 0);
+
+ fn->configure(_creation_context);
+ fn->add_operation(std::move(gpu_op));
+
+ _return_fn = std::move(fn);
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_KERNEL_GENERATOR_H__
+#define __ONERT_BACKEND_GPU_CL_KERNEL_GENERATOR_H__
+
+#include "ClTensorRegistry.h"
+#include "backend/basic/TensorRegistry.h"
+#include "TensorBuilder.h"
+#include "TensorManager.h"
+
+#include <backend/CustomKernelBuilder.h>
+#include <backend/basic/KernelGeneratorBase.h>
+#include <ir/Operands.h>
+#include <ir/Operations.h>
+#include <ir/Operations.Include.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class KernelGenerator : public basic::KernelGeneratorBase
+{
+public:
+ KernelGenerator(const ir::Graph &graph, const std::shared_ptr<TensorBuilder> &tensor_builder,
+ const std::shared_ptr<ClTensorRegistry<TensorManager>> &tensor_reg,
+ const std::shared_ptr<CreationContext> &creation_context);
+
+ std::unique_ptr<exec::FunctionSequence> generate(ir::OperationIndex ind) override;
+
+private:
+ void visit(const ir::operation::BinaryArithmetic &) override;
+ void visit(const ir::operation::Conv2D &) override;
+ void visit(const ir::operation::DepthwiseConv2D &) override;
+ void visit(const ir::operation::ElementwiseActivation &) override;
+ void visit(const ir::operation::Pool2D &) override;
+ void visit(const ir::operation::Reshape &) override;
+ void visit(const ir::operation::Softmax &) override;
+
+private:
+ const ir::Operands &_ctx;
+ const ir::Operations &_operations_ctx;
+ ir::Layout _current_layout;
+ std::shared_ptr<TensorBuilder> _tensor_builder;
+ std::shared_ptr<ClTensorRegistry<TensorManager>> _tensor_reg;
+ std::shared_ptr<CreationContext> _creation_context;
+ ir::OperandIndexMap<std::shared_ptr<Tensor>> _new_tensors;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_KERNEL_GENERATOR_H__
--- /dev/null
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_PARENT_INFO_H__
+#define __ONERT_BACKEND_PARENT_INFO_H__
+
+#include <ir/Index.h>
+#include <ir/Coordinates.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+/**
+ * @brief Struct to represent parent operand in child operand
+ */
+struct ParentInfo
+{
+ ir::OperandIndex parent;
+ ir::Layout frontend_layout;
+ ir::Coordinates coordinates;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_PARENT_INFO_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_TENSOR_BUILDER_H__
+#define __ONERT_BACKEND_GPU_CL_TENSOR_BUILDER_H__
+
+#include <backend/basic/TensorBuilder.h>
+#include "operand/ICLTensor.h"
+#include "operand/CLTensor.h"
+#include "ClTensorBuilder.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+using TensorBuilder = ClTensorBuilder<operand::ICLTensor, operand::CLTensor>;
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_TENSOR_BUILDER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_CL_TENSOR_MANAGER_H__
+#define __ONERT_BACKEND_CL_TENSOR_MANAGER_H__
+
+#include "ClMemoryManager.h"
+#include "ClTensorManager.h"
+#include "open_cl/ClContext.h"
+#include "operand/CLTensor.h"
+#include "operand/ICLTensor.h"
+#include "util/logging.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+using MemoryManager = ClMemoryManager<operand::ICLTensor, operand::CLTensor>;
+
+using TensorManager = ClTensorManager<operand::ICLTensor, operand::CLTensor>;
+
+inline TensorManager *createTensorManager(CLContext *context)
+{
+ VERBOSE(createTensorManager) << "ClTensorManager" << std::endl;
+ return new TensorManager(new MemoryManager(context), new MemoryManager(context));
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_CL_TENSOR_MANAGER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Backend.h"
+
+#include <util/logging.h>
+
+extern "C" {
+onert::backend::Backend *onert_backend_create()
+{
+ VERBOSE(onert_backend_create) << "'gpu_cl' loaded\n";
+ return new onert::backend::gpu_cl::Backend;
+}
+
+void onert_backend_destroy(onert::backend::Backend *backend)
+{
+ VERBOSE(onert_backend_destroy) << "'gpu_cl' unloaded\n";
+ delete backend;
+}
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_ACCESS_TYPE_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_ACCESS_TYPE_H__
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+enum class AccessType
+{
+ UNKNOWN,
+ READ,
+ WRITE,
+ READ_WRITE,
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_ACCESS_TYPE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Api.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+struct ObjectTypeGetter
+{
+ ObjectType operator()(absl::monostate) const { return ObjectType::UNKNOWN; }
+ ObjectType operator()(OpenClBuffer) const { return ObjectType::OPENCL_BUFFER; }
+ ObjectType operator()(OpenClTexture) const { return ObjectType::OPENCL_TEXTURE; }
+ ObjectType operator()(CpuMemory) const { return ObjectType::CPU_MEMORY; }
+};
+
+struct ObjectValidityChecker
+{
+ bool operator()(absl::monostate) const { return false; }
+ bool operator()(OpenClBuffer obj) const { return obj.memobj; }
+ bool operator()(OpenClTexture obj) const { return obj.memobj; }
+ bool operator()(CpuMemory obj) const
+ {
+ return obj.data != nullptr && obj.size_bytes > 0 &&
+ (data_type == DataType::UNKNOWN || obj.size_bytes % SizeOf(data_type) == 0);
+ }
+ DataType data_type;
+};
+
+} // namespace
+
+bool IsValid(const ObjectDef &def)
+{
+ return def.data_type != DataType::UNKNOWN && def.data_layout != DataLayout::UNKNOWN &&
+ def.object_type != ObjectType::UNKNOWN;
+}
+
+ObjectType GetType(const TensorObject &object) { return absl::visit(ObjectTypeGetter{}, object); }
+
+bool IsValid(const TensorObjectDef &def) { return IsValid(def.object_def); }
+
+bool IsValid(const TensorObjectDef &def, const TensorObject &object)
+{
+ return GetType(object) == def.object_def.object_type &&
+ absl::visit(ObjectValidityChecker{def.object_def.data_type}, object);
+}
+
+bool IsObjectPresent(ObjectType type, const TensorObject &obj)
+{
+ switch (type)
+ {
+ case ObjectType::CPU_MEMORY:
+ return absl::holds_alternative<CpuMemory>(obj);
+ case ObjectType::OPENCL_BUFFER:
+ return absl::holds_alternative<OpenClBuffer>(obj);
+ case ObjectType::OPENCL_TEXTURE:
+ return absl::holds_alternative<OpenClTexture>(obj);
+ case ObjectType::UNKNOWN:
+ return false;
+ }
+ return false;
+}
+
+uint32_t NumElements(const TensorObjectDef &def)
+{
+ const auto &d = def.dimensions;
+ switch (def.object_def.data_layout)
+ {
+ case DataLayout::BHWC:
+ return d.product();
+ case DataLayout::HWDC4:
+ case DataLayout::HDWC4:
+ case DataLayout::DHWC4:
+ return d.b * d.h * d.w * AlignByN(d.c, 4);
+ case DataLayout::UNKNOWN:
+ return 0;
+ }
+ return 0;
+}
+
+int GetPosition(const InferenceOptions &options, InferencePriority p)
+{
+ if (options.priority1 == p)
+ return 1;
+ if (options.priority2 == p)
+ return 2;
+ if (options.priority3 == p)
+ return 3;
+ return 4; // least important
+}
+
+PriorityImportance GetRelativeImportance(const InferenceOptions &options, InferencePriority p1,
+ InferencePriority p2)
+{
+ int p1_position = GetPosition(options, p1);
+ int p2_position = GetPosition(options, p2);
+ if (p1_position == p2_position)
+ return PriorityImportance::UNKNOWN;
+ return p1_position < p2_position ? PriorityImportance::HIGHER : PriorityImportance::LOWER;
+}
+
+bool IsValid(const InferenceOptions &options)
+{
+ if (options.usage == InferenceUsage::UNKNOWN)
+ {
+ return false;
+ }
+ if (options.priority1 == InferencePriority::UNKNOWN ||
+ options.priority2 == InferencePriority::UNKNOWN ||
+ options.priority3 == InferencePriority::UNKNOWN)
+ {
+ return false;
+ }
+ if (options.priority1 == InferencePriority::AUTO)
+ {
+ return false;
+ }
+ if (options.priority2 == InferencePriority::AUTO && options.priority3 != InferencePriority::AUTO)
+ {
+ return false;
+ }
+ if (options.priority1 == options.priority2 || options.priority1 == options.priority3)
+ {
+ return false;
+ }
+ if (options.priority2 == options.priority3 && options.priority2 != InferencePriority::AUTO)
+ {
+ return false;
+ }
+ return true;
+}
+
+// Implementation note: this resolution logic is shared between GL and CL
+// backends, but they might have their own logic. Thus, the function is defined
+// here just for code re-use purposes.
+void ResolveAutoPriority(InferenceOptions *options)
+{
+ // priority1 can not be AUTO as it would make options invalid.
+ if (options->priority2 == InferencePriority::AUTO)
+ {
+ switch (options->priority1)
+ {
+ case InferencePriority::MIN_LATENCY:
+ options->priority2 = InferencePriority::MIN_MEMORY_USAGE;
+ options->priority3 = InferencePriority::MAX_PRECISION;
+ return;
+ case InferencePriority::MIN_MEMORY_USAGE:
+ options->priority2 = InferencePriority::MAX_PRECISION;
+ options->priority3 = InferencePriority::MIN_LATENCY;
+ return;
+ case InferencePriority::MAX_PRECISION:
+ options->priority2 = InferencePriority::MIN_LATENCY;
+ options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
+ return;
+ case InferencePriority::UNKNOWN:
+ case InferencePriority::AUTO:
+ // Invalid and unreachable option.
+ return;
+ }
+ }
+
+ if (options->priority3 == InferencePriority::AUTO)
+ {
+ // Simply add missing priority
+ if (GetPosition(*options, InferencePriority::MIN_LATENCY) == 4)
+ {
+ options->priority3 = InferencePriority::MIN_LATENCY;
+ }
+ else if (GetPosition(*options, InferencePriority::MAX_PRECISION) == 4)
+ {
+ options->priority3 = InferencePriority::MAX_PRECISION;
+ }
+ else if (GetPosition(*options, InferencePriority::MIN_MEMORY_USAGE) == 4)
+ {
+ options->priority3 = InferencePriority::MIN_MEMORY_USAGE;
+ }
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_API_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_API_H__
+
+// Usage example:
+//
+// // Builder is created from a model using GPU-specific parameters.
+// std::unique_ptr<InferenceBuilder> builder = ...;
+//
+// // input data is coming from a texture
+// // output data goes to CPU
+// builder->SetInputObjectDef(0, {DataType::FLOAT16, DataLayout::DHWC4,
+// ObjectType::OPENCL_TEXTURE, true});
+// builder->SetOutputObjectDef(0, {DataType::FLOAT32, DataLayout::BHWC,
+// ObjectType::CPU_MEMORY, false});
+// std::unique_ptr<InferenceRunner> runner;
+// RETURN_IF_ERROR(builder->Build(&runner)); // may take significant time.
+// RETURN_IF_ERROR(
+// runner->SetInputObject(0, OpenClTexture{texture_memobj}));
+// RETURN_IF_ERROR(runner->Run());
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "absl/types/variant.h"
+#include "DataType.h"
+#include "Status.h"
+#include "Util.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// Common abbreviations:
+// B - batch
+// H - height
+// W - width
+// C - channels
+// D - depth := DivideRoundUp(C, 4)
+// C4 - is the constant = 4.
+enum class DataLayout
+{
+ UNKNOWN,
+ BHWC,
+ DHWC4,
+ HWDC4,
+ HDWC4,
+};
+
+enum class ObjectType
+{
+ UNKNOWN,
+ CPU_MEMORY,
+ OPENCL_TEXTURE,
+ OPENCL_BUFFER,
+};
+
+struct OpenClBuffer
+{
+ OpenClBuffer() = default;
+ explicit OpenClBuffer(cl_mem new_memobj) : memobj(new_memobj) {}
+
+ cl_mem memobj = nullptr;
+};
+
+struct OpenClTexture
+{
+ OpenClTexture() = default;
+ explicit OpenClTexture(cl_mem new_memobj) : memobj(new_memobj) {}
+
+ cl_mem memobj = nullptr;
+ // TODO(akulik): should it specify texture format?
+};
+
+struct CpuMemory
+{
+ CpuMemory() = default;
+ CpuMemory(void *new_data, size_t new_size_bytes) : data(new_data), size_bytes(new_size_bytes) {}
+
+ void *data = nullptr;
+ size_t size_bytes = 0;
+};
+
+template <typename T> inline CpuMemory MakeCpuMemory(absl::Span<T> t)
+{
+ CpuMemory m;
+ m.data = t.data();
+ m.size_bytes = t.size() * sizeof(T);
+ return m;
+}
+
+template <typename T> inline CpuMemory MakeReadableCpuMemory(absl::Span<const T> t)
+{
+ CpuMemory m;
+ m.data = const_cast<T *>(t.data());
+ m.size_bytes = t.size() * sizeof(T);
+ return m;
+}
+
+// Defines object representation.
+struct ObjectDef
+{
+ DataType data_type = DataType::UNKNOWN;
+ DataLayout data_layout = DataLayout::UNKNOWN;
+ ObjectType object_type = ObjectType::UNKNOWN;
+
+ // If true, then object is managed externally and needs to be provided to
+ // InferenceRunner by a user before running inference.
+ //
+ // User-provided objects will not be re-used internally for any purpose to
+ // lower overall memory usage.
+ bool user_provided = false;
+
+ bool operator==(const ObjectDef &other) const
+ {
+ return data_type == other.data_type && data_layout == other.data_layout &&
+ object_type == other.object_type && user_provided == other.user_provided;
+ }
+};
+
+bool IsValid(const ObjectDef &def);
+
+struct Dimensions
+{
+ Dimensions() : b(1), h(1), w(1), c(1) {}
+
+ Dimensions(int32_t batch, int32_t height, int32_t width, int32_t channels)
+ : b(batch), h(height), w(width), c(channels)
+ {
+ }
+
+ int32_t d() const { return DivideRoundUp(c, 4); }
+
+ int32_t product() const { return b * h * w * c; }
+
+ bool operator==(const Dimensions &other) const
+ {
+ return b == other.b && h == other.h && w == other.w && c == other.c;
+ }
+
+ int32_t b;
+ int32_t h;
+ int32_t w;
+ int32_t c;
+};
+
+// Connects tensor shape with corresponding object definition.
+struct TensorObjectDef
+{
+ // Dimensions semantics are defined by the corresponding DataLayout.
+ Dimensions dimensions;
+ ObjectDef object_def;
+
+ bool operator==(const TensorObjectDef &other) const
+ {
+ return dimensions == other.dimensions && object_def == other.object_def;
+ }
+};
+
+// @return true if tensor object def is defined.
+bool IsValid(const TensorObjectDef &def);
+
+// @return the number of elements in a tensor object.
+uint32_t NumElements(const TensorObjectDef &def);
+
+using TensorObject = absl::variant<absl::monostate, CpuMemory, OpenClBuffer, OpenClTexture>;
+
+// @return true if object is set and corresponding values are defined.
+bool IsValid(const TensorObjectDef &def, const TensorObject &object);
+
+ObjectType GetType(const TensorObject &object);
+
+// @return true if corresponding object is set for the given type
+bool IsObjectPresent(ObjectType type, const TensorObject &obj);
+
+class InferenceRunner;
+
+// Allows inspecting and changing input and output definitions before a graph is
+// prepared for inference.
+class InferenceBuilder
+{
+public:
+ virtual ~InferenceBuilder() {}
+
+ // Returns inference graph inputs and outputs definitions.
+ virtual std::vector<TensorObjectDef> inputs() const = 0;
+ virtual std::vector<TensorObjectDef> outputs() const = 0;
+
+ // Sets a new shape for the input if the underlying implementation and graph
+ // structure allow dynamic tensors.
+ virtual absl::Status SetInputShape(int index, const Dimensions &dimensions) = 0;
+
+ // Updates object definitions for the given index. The implementation may allow
+ // different layouts and/or data type conversions between objects defined in a
+ // graph and the given objects, for example:
+ // input '0' is DataType::FLOAT32, DataLayout::BHWC.
+ // A user, however, has an input in DataType::FLOAT16, DataLayout::DHWC4.
+ // An implementation may allow this transformation to happen automatically
+ // under the hood.
+ virtual absl::Status SetInputObjectDef(int index, ObjectDef def) = 0;
+ virtual absl::Status SetOutputObjectDef(int index, ObjectDef def) = 0;
+ virtual absl::Status SetAllInputObjectDefsTo(ObjectDef def)
+ {
+ auto input_defs = inputs();
+ for (size_t i = 0; i < input_defs.size(); ++i)
+ {
+ RETURN_IF_ERROR(SetInputObjectDef(i, def));
+ }
+ return absl::OkStatus();
+ }
+ virtual absl::Status SetAllOutputObjectDefsTo(ObjectDef def)
+ {
+ auto output_defs = outputs();
+ for (size_t i = 0; i < output_defs.size(); ++i)
+ {
+ RETURN_IF_ERROR(SetOutputObjectDef(i, def));
+ }
+ return absl::OkStatus();
+ }
+
+ // Creates a new instance of the inference runner. The InferenceBuilder stays valid
+ // and can be used to create another inference runner if needed.
+ //
+ // This method may take significant time to prepare a new inference runner. For
+ // example, it may need to compile OpenCL kernels.
+ virtual absl::Status Build(std::unique_ptr<InferenceRunner> *runner) = 0;
+};
+
+// Runs prepared inference. Every object marked as external needs to be set
+// prior to calling the Run method.
+class InferenceRunner
+{
+public:
+ virtual ~InferenceRunner() {}
+
+ // Returns inference graph inputs and outputs definitions.
+ virtual std::vector<TensorObjectDef> inputs() const = 0;
+ virtual std::vector<TensorObjectDef> outputs() const = 0;
+
+ // Getters provide access to underlying objects for the given index.
+ // Setters allow setting or changing the external object for the given index. Note,
+ // the object needs to match the object definition set earlier in InferenceBuilder.
+
+ virtual absl::Status GetInputObject(int index, TensorObject *object) = 0;
+ virtual absl::Status GetOutputObject(int index, TensorObject *object) = 0;
+ virtual absl::Status SetInputObject(int index, TensorObject object) = 0;
+ virtual absl::Status SetOutputObject(int index, TensorObject object) = 0;
+
+ virtual absl::Status Run() = 0;
+};
+
+// Encapsulated compilation/runtime tradeoffs.
+enum class InferenceUsage
+{
+ UNKNOWN,
+
+ // InferenceRunner will be used only once. Therefore, it is important to
+ // minimize bootstrap time as well.
+ FAST_SINGLE_ANSWER,
+
+ // Prefer maximizing the throughput. Same inference runner will be used
+ // repeatedly on different inputs.
+ SUSTAINED_SPEED,
+};
+
+// Defines aspects to control while instantiating a runner.
+enum class InferencePriority
+{
+ UNKNOWN,
+
+ AUTO,
+
+ MIN_LATENCY,
+
+ MAX_PRECISION,
+
+ MIN_MEMORY_USAGE,
+};
+
+struct InferenceOptions
+{
+ InferenceUsage usage = InferenceUsage::SUSTAINED_SPEED;
+
+ // Ordered priorities provide better understanding of desired semantics,
+ // where priority(n) is more important than priority(n+1).
+ // AUTO priority is needed when a single priority is the most important
+ // factor. For example, priority1 = InferencePriority::MIN_LATENCY and leaving
+ // everything else to AUTO would result in a configuration that achieves maximum
+ // performance.
+ //
+ // AUTO priority can only be used when higher priorities are fully specified.
+ // For example:
+ // VALID: priority1 = MIN_LATENCY, priority2 = AUTO, priority3 = AUTO
+ // VALID: priority1 = MIN_LATENCY, priority2 = MAX_PRECISION,
+ // priority3 = AUTO
+ // INVALID: priority1 = AUTO, priority2 = MIN_LATENCY, priority3 = AUTO
+ // INVALID: priority1 = MIN_LATENCY, priority2 = AUTO,
+ // priority3 = MAX_PRECISION
+ // Invalid priorities will result in error.
+ InferencePriority priority1 = InferencePriority::MAX_PRECISION;
+
+ InferencePriority priority2 = InferencePriority::AUTO;
+
+ InferencePriority priority3 = InferencePriority::AUTO;
+};
+
+// Returns a position number for the priority. If the priority is missing,
+// it returns 'max num priorities + 1'.
+int GetPosition(const InferenceOptions &options, InferencePriority p);
+
+// Return true if options are valid.
+bool IsValid(const InferenceOptions &options);
+
+// Resolves AUTO priorities and specifies them explicitly.
+// Note, no one should assume that these mappings will not change.
+// Technically this function is declared here for code re-use purposes, and
+// by no means should it be treated as the canonical way to resolve AUTO.
+void ResolveAutoPriority(InferenceOptions *options);
+
+enum class PriorityImportance
+{
+ UNKNOWN,
+ HIGHER,
+ LOWER,
+};
+
+// If both p1 and p2 are not present in options, return UNKNOWN
+// If p1 is present, but p2 is not, return HIGHER
+// If p2 is present, but p1 is not, return LOWER
+// If both are present, and p1 is more important, return HIGHER, otherwise,
+// LOWER.
+PriorityImportance GetRelativeImportance(const InferenceOptions &options, InferencePriority p1,
+ InferencePriority p2);
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_API_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Arguments.h"
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/substitute.h"
+
+#include "AccessType.h"
+#include "TensorType.h"
+#include "DataType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+namespace
+{
+
+bool IsWordSymbol(char symbol) { return absl::ascii_isalnum(symbol) || symbol == '_'; }
+
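+// Returns the identifier (letters, digits and '_') that starts at first_position.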
+std::string GetNextWord(const std::string &code, size_t first_position)
+{
+ size_t pos = first_position;
+ char t = code[pos];
+ while (IsWordSymbol(t))
+ {
+ pos++;
+ t = code[pos];
+ }
+ return code.substr(first_position, pos - first_position);
+}
+
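+// Given the position right after an opening bracket, returns the position just past its
+// matching closing bracket, or size_t(-1) if the bracket type is unknown or unbalanced.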
+size_t FindEnclosingBracket(const std::string &text, size_t first_pos, char bracket)
+{
+ const std::map<char, char> brackets = {
+ {'(', ')'},
+ {'{', '}'},
+ {'[', ']'},
+ {'<', '>'},
+ };
+ char b_open = bracket;
+ auto it = brackets.find(b_open);
+ if (it == brackets.end())
+ {
+ return -1;
+ }
+ char b_close = it->second;
+ size_t pos = first_pos;
+ int opened = 1;
+ int closed = 0;
+ while (opened != closed && pos < text.size())
+ {
+ if (text[pos] == b_open)
+ {
+ opened++;
+ }
+ else if (text[pos] == b_close)
+ {
+ closed++;
+ }
+ pos++;
+ }
+ if (opened == closed)
+ {
+ return pos;
+ }
+ else
+ {
+ return -1;
+ }
+}
+
+absl::Status ParseArgsInsideBrackets(const std::string &text, size_t open_bracket_pos,
+ size_t *close_bracket_pos, std::vector<std::string> *args)
+{
+ *close_bracket_pos = FindEnclosingBracket(text, open_bracket_pos + 1, text[open_bracket_pos]);
+ if (*close_bracket_pos == static_cast<size_t>(-1))
+ {
+ return absl::NotFoundError("Not found enclosing bracket");
+ }
+ std::string str_args =
+ text.substr(open_bracket_pos + 1, *close_bracket_pos - open_bracket_pos - 2);
+ std::vector<absl::string_view> words = absl::StrSplit(str_args, ',');
+ args->reserve(words.size());
+ for (const auto &word : words)
+ {
+ absl::string_view arg = absl::StripAsciiWhitespace(word);
+ if (!arg.empty())
+ {
+ args->push_back(std::string(arg));
+ }
+ }
+ return absl::OkStatus();
+}
+
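+// Replaces whole-word occurrences of old_word in *str with new_word; matches adjacent to
+// identifier characters are left untouched.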
+void ReplaceAllWords(const std::string &old_word, const std::string &new_word, std::string *str)
+{
+ size_t position = str->find(old_word);
+ while (position != std::string::npos)
+ {
+ char prev = position == 0 ? '.' : (*str)[position - 1];
+ char next = position + old_word.size() < str->size() ? (*str)[position + old_word.size()] : '.';
+ if (IsWordSymbol(prev) || IsWordSymbol(next))
+ {
+ position = str->find(old_word, position + 1);
+ continue;
+ }
+ str->replace(position, old_word.size(), new_word);
+ position = str->find(old_word, position + new_word.size());
+ }
+}
+
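+// If arg_name starts with one of object_names (i.e. "<object>_<member>"), inserts postfix
+// right after the object name; otherwise appends postfix to the whole argument name.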
+std::string RenameArg(const std::vector<std::string> &object_names, const std::string &postfix,
+ const std::string &arg_name)
+{
+ for (const auto &object_name : object_names)
+ {
+ if (absl::StartsWith(arg_name, object_name) && arg_name.size() > object_name.size() &&
+ arg_name[object_name.size()] == '_')
+ {
+ return object_name + postfix +
+ arg_name.substr(object_name.size(), arg_name.size() - object_name.size());
+ }
+ }
+ return arg_name + postfix;
+}
+
+void AppendArgument(const std::string &arg, std::string *args)
+{
+ if (!args->empty())
+ {
+ absl::StrAppend(args, ",\n ");
+ }
+ absl::StrAppend(args, arg);
+}
+
+std::string GetImageModifier(AccessType access)
+{
+ switch (access)
+ {
+ case AccessType::READ:
+ return "__read_only";
+ case AccessType::WRITE:
+ return "__write_only";
+ case AccessType::READ_WRITE:
+ return "__read_write";
+ default:
+ throw std::runtime_error("Invalid AccessType");
+ }
+}
+
+std::string GetDefaultSamplers(const DeviceInfo &device_info)
+{
+ std::string result;
+ result += "__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | "
+ "CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n";
+ if (device_info.IsAdreno3xx())
+ {
+ // Unfortunately, CLK_ADDRESS_CLAMP is very slow on Adreno3xx and
+ // we can observe huge register overhead when compared to other modes.
+
+ // While using CLK_ADDRESS_NONE with out-of-range image coordinates is
+ // undefined in the OpenCL specification, we have observed that
+ // CLK_ADDRESS_NONE works like CLK_ADDRESS_CLAMP for out-of-range image
+ // coordinates for RGBA F16/F32 textures on Adreno3xx devices. Using
+ // CLK_ADDRESS_NONE is significantly faster than CLK_ADDRESS_CLAMP on Adreno
+ // 3xx.
+ result += "__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | "
+ "CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n";
+ }
+ else
+ {
+ result += "__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | "
+ "CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n";
+ }
+
+ return result;
+}
+
+} // namespace
+
+// Static
+constexpr char Arguments::kArgsPrefix[];
+
+Arguments::Arguments(Arguments &&args)
+ : int_values_(std::move(args.int_values_)),
+ shared_int4s_data_(std::move(args.shared_int4s_data_)),
+ float_values_(std::move(args.float_values_)),
+ shared_float4s_data_(std::move(args.shared_float4s_data_)), buffers_(std::move(args.buffers_)),
+ images2d_(std::move(args.images2d_)), image2d_arrays_(std::move(args.image2d_arrays_)),
+ images3d_(std::move(args.images3d_)), image_buffers_(std::move(args.image_buffers_)),
+ custom_memories_(std::move(args.custom_memories_)), object_refs_(std::move(args.object_refs_)),
+ objects_(std::move(args.objects_))
+{
+}
+Arguments &Arguments::operator=(Arguments &&args)
+{
+ if (this != &args)
+ {
+ int_values_ = std::move(args.int_values_);
+ shared_int4s_data_ = std::move(args.shared_int4s_data_);
+ float_values_ = std::move(args.float_values_);
+ shared_float4s_data_ = std::move(args.shared_float4s_data_);
+ buffers_ = std::move(args.buffers_);
+ images2d_ = std::move(args.images2d_);
+ image2d_arrays_ = std::move(args.image2d_arrays_);
+ images3d_ = std::move(args.images3d_);
+ image_buffers_ = std::move(args.image_buffers_);
+ custom_memories_ = std::move(args.custom_memories_);
+ object_refs_ = std::move(args.object_refs_);
+ objects_ = std::move(args.objects_);
+ }
+ return *this;
+}
+
+void Arguments::AddFloat(const std::string &name, float value)
+{
+ float_values_[name].value = value;
+}
+void Arguments::AddInt(const std::string &name, int value) { int_values_[name].value = value; }
+void Arguments::AddBuffer(const std::string &name, const GPUBufferDescriptor &desc)
+{
+ buffers_[name] = desc;
+}
+void Arguments::AddImage2D(const std::string &name, const GPUImage2DDescriptor &desc)
+{
+ images2d_[name] = desc;
+}
+
+void Arguments::AddImage2DArray(const std::string &name, const GPUImage2DArrayDescriptor &desc)
+{
+ image2d_arrays_[name] = desc;
+}
+
+void Arguments::AddImage3D(const std::string &name, const GPUImage3DDescriptor &desc)
+{
+ images3d_[name] = desc;
+}
+
+void Arguments::AddImageBuffer(const std::string &name, const GPUImageBufferDescriptor &desc)
+{
+ image_buffers_[name] = desc;
+}
+
+void Arguments::AddCustomMemory(const std::string &name, const GPUCustomMemoryDescriptor &desc)
+{
+ custom_memories_[name] = desc;
+}
+
+void Arguments::AddObjectRef(const std::string &name, AccessType access_type,
+ GPUObjectDescriptorPtr &&descriptor_ptr)
+{
+ descriptor_ptr->SetAccess(access_type);
+ object_refs_[name] = {std::move(descriptor_ptr)};
+}
+
+void Arguments::AddObject(const std::string &name, GPUObjectDescriptorPtr &&descriptor_ptr)
+{
+ descriptor_ptr->SetAccess(AccessType::READ);
+ objects_[name] = {nullptr, std::move(descriptor_ptr)};
+}
+
+void Arguments::AddGPUResources(const std::string &name, const GPUResources &resources)
+{
+ for (const auto &r : resources.ints)
+ {
+ AddInt(absl::StrCat(name, "_", r));
+ }
+ for (const auto &r : resources.floats)
+ {
+ AddFloat(absl::StrCat(name, "_", r));
+ }
+ for (const auto &r : resources.buffers)
+ {
+ AddBuffer(absl::StrCat(name, "_", r.first), r.second);
+ }
+ for (const auto &r : resources.images2d)
+ {
+ AddImage2D(absl::StrCat(name, "_", r.first), r.second);
+ }
+ for (const auto &r : resources.image2d_arrays)
+ {
+ AddImage2DArray(absl::StrCat(name, "_", r.first), r.second);
+ }
+ for (const auto &r : resources.images3d)
+ {
+ AddImage3D(absl::StrCat(name, "_", r.first), r.second);
+ }
+ for (const auto &r : resources.image_buffers)
+ {
+ AddImageBuffer(absl::StrCat(name, "_", r.first), r.second);
+ }
+ for (const auto &r : resources.custom_memories)
+ {
+ AddCustomMemory(absl::StrCat(name, "_", r.first), r.second);
+ }
+}
+
+absl::Status Arguments::SetInt(const std::string &name, int value)
+{
+ auto it = int_values_.find(name);
+ if (it == int_values_.end())
+ {
+ return absl::NotFoundError(absl::StrCat("No int argument with name - ", name));
+ }
+ it->second.value = value;
+ if (it->second.active)
+ {
+ shared_int4s_data_[it->second.offset] = value;
+ }
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::SetFloat(const std::string &name, float value)
+{
+ auto it = float_values_.find(name);
+ if (it == float_values_.end())
+ {
+ return absl::NotFoundError(absl::StrCat("No float argument with name - ", name));
+ }
+ it->second.value = value;
+ if (it->second.active)
+ {
+ shared_float4s_data_[it->second.offset] = value;
+ }
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::SetImage2D(const std::string &name, cl_mem memory)
+{
+ auto it = images2d_.find(name);
+ if (it == images2d_.end())
+ {
+ return absl::NotFoundError(absl::StrCat("No image2D argument with name - ", name));
+ }
+ it->second.memory = memory;
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::SetBuffer(const std::string &name, cl_mem memory)
+{
+ auto it = buffers_.find(name);
+ if (it == buffers_.end())
+ {
+ return absl::NotFoundError(absl::StrCat("No buffer argument with name - ", name));
+ }
+ it->second.memory = memory;
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::SetImage2DArray(const std::string &name, cl_mem memory)
+{
+ auto it = image2d_arrays_.find(name);
+ if (it == image2d_arrays_.end())
+ {
+ return absl::NotFoundError(absl::StrCat("No image2D array argument with name - ", name));
+ }
+ it->second.memory = memory;
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::SetImage3D(const std::string &name, cl_mem memory)
+{
+ auto it = images3d_.find(name);
+ if (it == images3d_.end())
+ {
+ return absl::NotFoundError(absl::StrCat("No image3D argument with name - ", name));
+ }
+ it->second.memory = memory;
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::SetImageBuffer(const std::string &name, cl_mem memory)
+{
+ auto it = image_buffers_.find(name);
+ if (it == image_buffers_.end())
+ {
+ return absl::NotFoundError(absl::StrCat("No image buffer argument with name - ", name));
+ }
+ it->second.memory = memory;
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::SetCustomMemory(const std::string &name, cl_mem memory)
+{
+ auto it = custom_memories_.find(name);
+ if (it == custom_memories_.end())
+ {
+ return absl::NotFoundError(absl::StrCat("No custom memory argument with name - ", name));
+ }
+ it->second.memory = memory;
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::SetObjectRef(const std::string &name, const GPUObject *object)
+{
+ auto it = object_refs_.find(name);
+ if (it == object_refs_.end())
+ {
+ return absl::NotFoundError(absl::StrCat("No object ref with name - ", name));
+ }
+ GPUResourcesWithValue resources;
+ RETURN_IF_ERROR(object->GetGPUResources(it->second.descriptor.get(), &resources));
+ return SetGPUResources(name, resources);
+}
+
+absl::Status Arguments::SetGPUResources(const std::string &name,
+ const GPUResourcesWithValue &resources)
+{
+ for (const auto &r : resources.ints)
+ {
+ RETURN_IF_ERROR(SetInt(absl::StrCat(name, "_", r.first), r.second));
+ }
+ for (const auto &r : resources.floats)
+ {
+ RETURN_IF_ERROR(SetFloat(absl::StrCat(name, "_", r.first), r.second));
+ }
+ for (const auto &r : resources.buffers)
+ {
+ RETURN_IF_ERROR(SetBuffer(absl::StrCat(name, "_", r.first), r.second));
+ }
+ for (const auto &r : resources.images2d)
+ {
+ RETURN_IF_ERROR(SetImage2D(absl::StrCat(name, "_", r.first), r.second));
+ }
+ for (const auto &r : resources.image2d_arrays)
+ {
+ RETURN_IF_ERROR(SetImage2DArray(absl::StrCat(name, "_", r.first), r.second));
+ }
+ for (const auto &r : resources.images3d)
+ {
+ RETURN_IF_ERROR(SetImage3D(absl::StrCat(name, "_", r.first), r.second));
+ }
+ for (const auto &r : resources.image_buffers)
+ {
+ RETURN_IF_ERROR(SetImageBuffer(absl::StrCat(name, "_", r.first), r.second));
+ }
+ for (const auto &r : resources.custom_memories)
+ {
+ RETURN_IF_ERROR(SetCustomMemory(absl::StrCat(name, "_", r.first), r.second));
+ }
+ return absl::OkStatus();
+}
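+
+// Appends `postfix` to every argument name that follows kArgsPrefix in the given code string.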
+void Arguments::RenameArgs(const std::string &postfix, std::string *code) const
+{
+ size_t next_position = code->find(kArgsPrefix);
+ while (next_position != std::string::npos)
+ {
+ size_t arg_pos = next_position + strlen(kArgsPrefix);
+ std::string arg_name = GetNextWord(*code, arg_pos);
+ code->replace(arg_pos, arg_name.size(), arg_name + postfix);
+ next_position = code->find(kArgsPrefix, arg_pos + arg_name.size());
+ }
+}
+
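+// Merges all entries of `args` into this object, renaming them with `postfix`; a renamed
+// object or object reference that already exists here is reported as an error.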
+absl::Status Arguments::Merge(Arguments &&args, const std::string &postfix)
+{
+ std::vector<std::string> object_names;
+ object_names.reserve(args.object_refs_.size() + args.objects_.size());
+ for (auto &v : args.object_refs_)
+ {
+ object_names.push_back(v.first);
+ const std::string name = v.first + postfix;
+ if (object_refs_.find(name) != object_refs_.end())
+ {
+ return absl::InvalidArgumentError(
+ absl::StrCat("Object reference name collision. Name - ", name));
+ }
+ object_refs_[name] = {std::move(v.second.descriptor)};
+ }
+ for (auto &v : args.objects_)
+ {
+ object_names.push_back(v.first);
+ const std::string name = v.first + postfix;
+ if (objects_.find(name) != objects_.end())
+ {
+ return absl::InvalidArgumentError(absl::StrCat("Object name collision. Name - ", name));
+ }
+ objects_[name] = {std::move(v.second.obj_ptr), std::move(v.second.descriptor)};
+ }
+ for (const auto &v : args.int_values_)
+ {
+ AddInt(RenameArg(object_names, postfix, v.first), v.second.value);
+ }
+ for (const auto &v : args.float_values_)
+ {
+ AddFloat(RenameArg(object_names, postfix, v.first), v.second.value);
+ }
+ for (const auto &v : args.buffers_)
+ {
+ AddBuffer(RenameArg(object_names, postfix, v.first), v.second);
+ }
+ for (const auto &v : args.images2d_)
+ {
+ AddImage2D(RenameArg(object_names, postfix, v.first), v.second);
+ }
+ for (const auto &v : args.image2d_arrays_)
+ {
+ AddImage2DArray(RenameArg(object_names, postfix, v.first), v.second);
+ }
+ for (const auto &v : args.images3d_)
+ {
+ AddImage3D(RenameArg(object_names, postfix, v.first), v.second);
+ }
+ for (const auto &v : args.image_buffers_)
+ {
+ AddImageBuffer(RenameArg(object_names, postfix, v.first), v.second);
+ }
+ for (const auto &v : args.custom_memories_)
+ {
+ AddCustomMemory(RenameArg(object_names, postfix, v.first), v.second);
+ }
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::TransformToCLCode(const DeviceInfo &device_info,
+ const std::map<std::string, std::string> &linkables,
+ std::string *code)
+{
+ RETURN_IF_ERROR(AddObjectArgs());
+ RETURN_IF_ERROR(ResolveSelectorsPass(linkables, code));
+ ResolveArgsPass(device_info, code);
+ *code = absl::Substitute(*code, GetListOfArgs());
+ *code = GetDefaultSamplers(device_info) + *code;
+ return absl::OkStatus();
+}
+
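+// Builds the OpenCL kernel parameter list. The declaration order must match the order in
+// which Bind() sets the kernel arguments: buffers, image buffers, 2D images, 2D image
+// arrays, 3D images, custom memories, then the packed int4/float4 scalars.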
+std::string Arguments::GetListOfArgs()
+{
+ std::string result;
+ for (auto &t : buffers_)
+ {
+ const std::string type_name = t.second.data_type == DataType::FLOAT32 ? "float" : "half";
+ std::string attributes;
+ for (const auto &attr : t.second.attributes)
+ {
+ attributes += absl::StrCat(" __attribute__((", attr, "))");
+ }
+ AppendArgument(absl::StrCat(MemoryTypeToCLType(t.second.memory_type), " ",
+ ToCLDataType(t.second.data_type, t.second.element_size), "* ",
+ t.first, attributes),
+ &result);
+ }
+ for (auto &t : image_buffers_)
+ {
+ AppendArgument(
+ absl::StrCat(GetImageModifier(t.second.access_type), " image1d_buffer_t ", t.first), &result);
+ }
+ for (auto &t : images2d_)
+ {
+ AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type), " image2d_t ", t.first),
+ &result);
+ }
+ for (auto &t : image2d_arrays_)
+ {
+ AppendArgument(
+ absl::StrCat(GetImageModifier(t.second.access_type), " image2d_array_t ", t.first), &result);
+ }
+ for (auto &t : images3d_)
+ {
+ AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type), " image3d_t ", t.first),
+ &result);
+ }
+ for (auto &t : custom_memories_)
+ {
+ AppendArgument(absl::StrCat(t.second.type_name, " ", t.first), &result);
+ }
+ for (uint32_t i = 0; i < shared_int4s_data_.size() / 4; ++i)
+ {
+ AppendArgument(absl::StrCat("int4 shared_int4_", i), &result);
+ }
+ for (uint32_t i = 0; i < shared_float4s_data_.size() / 4; ++i)
+ {
+ AppendArgument(absl::StrCat("float4 shared_float4_", i), &result);
+ }
+ return result;
+}
+
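+// Sets every registered resource as a kernel argument, starting at `offset`, in the same
+// order GetListOfArgs() declares them.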
+absl::Status Arguments::Bind(cl_kernel kernel, int offset)
+{
+ for (auto &t : buffers_)
+ {
+ const int error_code = clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+ CLErrorCodeToString(error_code), "(at index - ",
+ offset, ")"));
+ }
+ offset++;
+ }
+ for (auto &t : image_buffers_)
+ {
+ const int error_code = clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+ CLErrorCodeToString(error_code), "(at index - ",
+ offset, ")"));
+ }
+ offset++;
+ }
+ for (auto &t : images2d_)
+ {
+ const int error_code = clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+ CLErrorCodeToString(error_code), "(at index - ",
+ offset, ")"));
+ }
+ offset++;
+ }
+ for (auto &t : image2d_arrays_)
+ {
+ const int error_code = clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+ CLErrorCodeToString(error_code), "(at index - ",
+ offset, ")"));
+ }
+ offset++;
+ }
+ for (auto &t : images3d_)
+ {
+ const int error_code = clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+ CLErrorCodeToString(error_code), "(at index - ",
+ offset, ")"));
+ }
+ offset++;
+ }
+ for (auto &t : custom_memories_)
+ {
+ const int error_code = clSetKernelArg(kernel, offset, sizeof(cl_mem), &t.second.memory);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+ CLErrorCodeToString(error_code), "(at index - ",
+ offset, ")"));
+ }
+ offset++;
+ }
+ for (size_t i = 0; i < shared_int4s_data_.size() / 4; ++i)
+ {
+ const int error_code =
+ clSetKernelArg(kernel, offset, sizeof(int32_t) * 4, &shared_int4s_data_[i * 4]);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+ CLErrorCodeToString(error_code), "(at index - ",
+ offset, ")"));
+ }
+ offset++;
+ }
+ for (size_t i = 0; i < shared_float4s_data_.size() / 4; ++i)
+ {
+ const int error_code =
+ clSetKernelArg(kernel, offset, sizeof(int32_t) * 4, &shared_float4s_data_[i * 4]);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+ CLErrorCodeToString(error_code), "(at index - ",
+ offset, ")"));
+ }
+ offset++;
+ }
+ return absl::OkStatus();
+}
+
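+// Marks the named scalar argument as used by the kernel: assigns it a slot in the packed
+// shared_int4/shared_float4 storage (if it does not have one yet) and returns the expression
+// that replaces it in the kernel source, e.g. "shared_int4_0.y". Unknown names are returned
+// unchanged.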
+std::string Arguments::AddActiveArgument(const std::string &arg_name, bool)
+{
+ {
+ auto it = int_values_.find(arg_name);
+ if (it != int_values_.end())
+ {
+ int int_index;
+ if (it->second.active)
+ {
+ int_index = it->second.offset;
+ }
+ else
+ {
+ it->second.active = true;
+ it->second.offset = shared_int4s_data_.size();
+ int_index = it->second.offset;
+ shared_int4s_data_.push_back(it->second.value);
+ }
+ std::string index = std::to_string(int_index / 4);
+ std::string postfixes[4] = {"x", "y", "z", "w"};
+ return "shared_int4_" + index + "." + postfixes[int_index % 4];
+ }
+ }
+ {
+ auto it = float_values_.find(arg_name);
+ if (it != float_values_.end())
+ {
+ int float_index;
+ if (it->second.active)
+ {
+ float_index = it->second.offset;
+ }
+ else
+ {
+ it->second.active = true;
+ it->second.offset = shared_float4s_data_.size();
+ float_index = it->second.offset;
+ shared_float4s_data_.push_back(it->second.value);
+ }
+ std::string index = std::to_string(float_index / 4);
+ std::string postfixes[4] = {"x", "y", "z", "w"};
+ return "shared_float4_" + index + "." + postfixes[float_index % 4];
+ }
+ }
+ return arg_name;
+}
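+
+// Illustrative sketch (comments only, not part of the build): how AddActiveArgument
+// packs scalar arguments. The name "src_width" is a hypothetical example. The first
+// time an int argument becomes active it is appended to shared_int4s_data_, so a
+// value registered with AddInt("src_width") is referenced in the generated kernel
+// code as one lane of a packed int4 uniform:
+//
+//   AddActiveArgument("src_width", false);  // returns "shared_int4_0.x"
+//   // a second activated int would map to "shared_int4_0.y", and so on.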
+
+void Arguments::ResolveArgsPass(const DeviceInfo &device_info, std::string *code)
+{
+ bool use_f32_for_half_arguments = device_info.IsPowerVR();
+ size_t position = 0;
+ size_t next_position = code->find(kArgsPrefix);
+ while (next_position != std::string::npos)
+ {
+ size_t arg_pos = next_position;
+ next_position += strlen(kArgsPrefix);
+ std::string object_name = GetNextWord(*code, next_position);
+ std::string new_name = AddActiveArgument(object_name, use_f32_for_half_arguments);
+ code->replace(arg_pos, object_name.size() + strlen(kArgsPrefix), new_name);
+ position = arg_pos + new_name.size();
+ next_position = code->find(kArgsPrefix, position);
+ }
+
+ int shared_int4s_aligned_size = AlignByN(shared_int4s_data_.size(), 4);
+ shared_int4s_data_.resize(shared_int4s_aligned_size);
+ int shared_float4s_aligned_size = AlignByN(shared_float4s_data_.size(), 4);
+ shared_float4s_data_.resize(shared_float4s_aligned_size);
+}
+
+void Arguments::ResolveObjectNames(const std::string &object_name,
+ const std::vector<std::string> &member_names, std::string *code)
+{
+ for (const auto &member_name : member_names)
+ {
+ const std::string new_name = kArgsPrefix + object_name + "_" + member_name;
+ ReplaceAllWords(member_name, new_name, code);
+ }
+}
+
+GPUObjectDescriptor *Arguments::GetObjectDescriptor(const std::string &object_name) const
+{
+ {
+ auto it = object_refs_.find(object_name);
+ if (it != object_refs_.end())
+ {
+ return it->second.descriptor.get();
+ }
+ }
+ {
+ auto it = objects_.find(object_name);
+ if (it != objects_.end())
+ {
+ return it->second.descriptor.get();
+ }
+ }
+ return nullptr;
+}
+
+absl::Status Arguments::ResolveSelector(const std::map<std::string, std::string> &linkables,
+ const std::string &object_name, const std::string &selector,
+ const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result)
+{
+ const GPUObjectDescriptor *desc_ptr = GetObjectDescriptor(object_name);
+ if (!desc_ptr)
+ {
+ return absl::NotFoundError(absl::StrCat("No object with name - ", object_name));
+ }
+ auto names = desc_ptr->GetGPUResources().GetNames();
+ const auto *tensor_desc = dynamic_cast<const TensorDescriptor *>(desc_ptr);
+ if (tensor_desc && selector == "Write")
+ {
+ auto it = linkables.find(object_name);
+ if (it != linkables.end())
+ {
+ if (desc_ptr->GetAccess() != AccessType::WRITE &&
+ desc_ptr->GetAccess() != AccessType::READ_WRITE)
+ {
+ return absl::FailedPreconditionError(
+ absl::StrCat("Object with name - ", object_name, " should have Write access."));
+ }
+ std::string value_name, x_coord, y_coord, s_coord;
+ RETURN_IF_ERROR(tensor_desc->GetLinkingContextFromWriteSelector(args, &value_name, &x_coord,
+ &y_coord, &s_coord));
+      // x_coord may carry the batch-size component of the linked object
+ ResolveObjectNames(object_name, names, &x_coord);
+ *result = it->second;
+ ReplaceAllWords("in_out_value", value_name, result);
+ ReplaceAllWords("X_COORD", x_coord, result);
+ ReplaceAllWords("Y_COORD", y_coord, result);
+ ReplaceAllWords("S_COORD", s_coord, result);
+ RETURN_IF_ERROR(ResolveSelectorsPass({}, result));
+ }
+ }
+ std::string patch;
+ RETURN_IF_ERROR(desc_ptr->PerformSelector(selector, args, template_args, &patch));
+ ResolveObjectNames(object_name, names, &patch);
+ *result += patch;
+ return absl::OkStatus();
+}
+
+absl::Status Arguments::ResolveSelectorsPass(const std::map<std::string, std::string> &linkables,
+ std::string *code)
+{
+ std::string result;
+ size_t position = 0;
+ size_t next_position = code->find(kArgsPrefix);
+ while (next_position != std::string::npos)
+ {
+ size_t arg_pos = next_position;
+ next_position += strlen(kArgsPrefix);
+ std::string object_name = GetNextWord(*code, next_position);
+ char next = (*code)[next_position + object_name.size()];
+ if (next == '.')
+ {
+ next_position += object_name.size() + 1;
+ std::string selector_name = GetNextWord(*code, next_position);
+ next_position += selector_name.size();
+ next = (*code)[next_position];
+ std::vector<std::string> template_args;
+ if (next == '<')
+ {
+ size_t close_bracket_pos;
+ RETURN_IF_ERROR(
+ ParseArgsInsideBrackets(*code, next_position, &close_bracket_pos, &template_args));
+ next_position = close_bracket_pos;
+ next = (*code)[next_position];
+ }
+ if (next != '(')
+ {
+ return absl::NotFoundError(
+ absl::StrCat("Expected ( after ", object_name, ".", selector_name, " call"));
+ }
+ std::vector<std::string> args;
+ size_t close_bracket_pos;
+ RETURN_IF_ERROR(ParseArgsInsideBrackets(*code, next_position, &close_bracket_pos, &args));
+ for (auto &arg : args)
+ {
+ RETURN_IF_ERROR(ResolveSelectorsPass({}, &arg));
+ }
+ std::string patch;
+ RETURN_IF_ERROR(
+ ResolveSelector(linkables, object_name, selector_name, args, template_args, &patch));
+ code->replace(arg_pos, close_bracket_pos - arg_pos, patch);
+ position = arg_pos + patch.size();
+ }
+ else
+ {
+ position = arg_pos + strlen(kArgsPrefix);
+ }
+ next_position = code->find(kArgsPrefix, position);
+ }
+ return absl::OkStatus();
+}
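+
+// Illustrative sketch (comments only, not part of the build): the rewrite performed
+// by ResolveSelectorsPass. The tensor name "src_tensor" is hypothetical; the exact
+// replacement text is whatever PerformSelector produces for that object.
+//
+//   before: FLT4 v = args.src_tensor.Read(X, Y);
+//   after:  the "args.src_tensor.Read(X, Y)" call is replaced by the selector's
+//           patch, with the object's member names prefixed as "args.src_tensor_*"
+//           so that ResolveArgsPass can later rewrite them into the kernel
+//           parameter names declared for that object.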
+
+absl::Status Arguments::AllocateObjects(CLContext *context)
+{
+ for (auto &t : objects_)
+ {
+ RETURN_IF_ERROR(t.second.descriptor->CreateGPUObject(context, &t.second.obj_ptr));
+ }
+ return absl::OkStatus();
+}
+
+void Arguments::ReleaseCPURepresentation()
+{
+ for (auto &t : objects_)
+ {
+ t.second.descriptor->Release();
+ }
+}
+
+absl::Status Arguments::AddObjectArgs()
+{
+ for (auto &t : objects_)
+ {
+ AddGPUResources(t.first, t.second.descriptor->GetGPUResources());
+ GPUResourcesWithValue resources;
+ RETURN_IF_ERROR(t.second.obj_ptr->GetGPUResources(t.second.descriptor.get(), &resources));
+ RETURN_IF_ERROR(SetGPUResources(t.first, resources));
+ }
+ for (auto &t : object_refs_)
+ {
+ AddGPUResources(t.first, t.second.descriptor->GetGPUResources());
+ }
+ return absl::OkStatus();
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_ARGUMENTS_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_ARGUMENTS_H__
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "ClDevice.h"
+#include "GpuObject.h"
+#include "OpenclWrapper.h"
+
+#include "AccessType.h"
+#include "Types.h"
+#include "Util.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class ArgumentsBinder
+{
+public:
+ virtual absl::Status SetInt(const std::string &name, int value) = 0;
+ virtual absl::Status SetFloat(const std::string &name, float value) = 0;
+ virtual ~ArgumentsBinder() = default;
+};
+
+class Arguments : public ArgumentsBinder
+{
+public:
+ Arguments() = default;
+ void AddFloat(const std::string &name, float value = 0.0f);
+ void AddInt(const std::string &name, int value = 0);
+ void AddObjectRef(const std::string &name, AccessType access_type,
+ GPUObjectDescriptorPtr &&descriptor_ptr);
+ void AddObject(const std::string &name, GPUObjectDescriptorPtr &&descriptor_ptr);
+
+ absl::Status SetInt(const std::string &name, int value) override;
+ absl::Status SetFloat(const std::string &name, float value) override;
+ absl::Status SetObjectRef(const std::string &name, const GPUObject *object);
+
+ absl::Status Bind(cl_kernel kernel, int offset = 0);
+
+ void RenameArgs(const std::string &postfix, std::string *code) const;
+ absl::Status Merge(Arguments &&args, const std::string &postfix);
+
+ absl::Status AllocateObjects(CLContext *context);
+ void ReleaseCPURepresentation();
+ absl::Status TransformToCLCode(const DeviceInfo &device_info,
+ const std::map<std::string, std::string> &linkables,
+ std::string *code);
+
+ // Move only
+ Arguments(Arguments &&args);
+ Arguments &operator=(Arguments &&args);
+ Arguments(const Arguments &) = delete;
+ Arguments &operator=(const Arguments &) = delete;
+
+ ~Arguments() override = default;
+
+private:
+ void AddBuffer(const std::string &name, const GPUBufferDescriptor &desc);
+ void AddImage2D(const std::string &name, const GPUImage2DDescriptor &desc);
+ void AddImage2DArray(const std::string &name, const GPUImage2DArrayDescriptor &desc);
+ void AddImage3D(const std::string &name, const GPUImage3DDescriptor &desc);
+ void AddImageBuffer(const std::string &name, const GPUImageBufferDescriptor &desc);
+ void AddCustomMemory(const std::string &name, const GPUCustomMemoryDescriptor &desc);
+
+ absl::Status SetImage2D(const std::string &name, cl_mem memory);
+ absl::Status SetBuffer(const std::string &name, cl_mem memory);
+ absl::Status SetImage2DArray(const std::string &name, cl_mem memory);
+ absl::Status SetImage3D(const std::string &name, cl_mem memory);
+ absl::Status SetImageBuffer(const std::string &name, cl_mem memory);
+ absl::Status SetCustomMemory(const std::string &name, cl_mem memory);
+
+ std::string GetListOfArgs();
+
+ std::string AddActiveArgument(const std::string &arg_name, bool use_f32_for_halfs);
+ void AddGPUResources(const std::string &name, const GPUResources &resources);
+
+ absl::Status SetGPUResources(const std::string &name, const GPUResourcesWithValue &resources);
+
+ absl::Status AddObjectArgs();
+
+ void ResolveArgsPass(const DeviceInfo &device_info, std::string *code);
+ absl::Status ResolveSelectorsPass(const std::map<std::string, std::string> &linkables,
+ std::string *code);
+
+ absl::Status ResolveSelector(const std::map<std::string, std::string> &linkables,
+ const std::string &object_name, const std::string &selector,
+ const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args, std::string *result);
+
+ void ResolveObjectNames(const std::string &object_name,
+ const std::vector<std::string> &member_names, std::string *code);
+
+ GPUObjectDescriptor *GetObjectDescriptor(const std::string &object_name) const;
+
+ static constexpr char kArgsPrefix[] = "args.";
+
+ struct IntValue
+ {
+ int value;
+
+    // Many uniforms are generated automatically but never used; tracking the
+    // active ones reduces the amount of data transferred to the GPU.
+ bool active = false;
+
+ // offset to shared uniform storage.
+ uint32_t offset = -1;
+ };
+ std::map<std::string, IntValue> int_values_;
+ std::vector<int32_t> shared_int4s_data_;
+
+ struct FloatValue
+ {
+ float value;
+
+    // Many uniforms are generated automatically but never used; tracking the
+    // active ones reduces the amount of data transferred to the GPU.
+ bool active = false;
+
+ // offset to shared uniform storage.
+ uint32_t offset = -1;
+ };
+ std::map<std::string, FloatValue> float_values_;
+ std::vector<float> shared_float4s_data_;
+
+ std::map<std::string, GPUBufferDescriptor> buffers_;
+ std::map<std::string, GPUImage2DDescriptor> images2d_;
+ std::map<std::string, GPUImage2DArrayDescriptor> image2d_arrays_;
+ std::map<std::string, GPUImage3DDescriptor> images3d_;
+ std::map<std::string, GPUImageBufferDescriptor> image_buffers_;
+ std::map<std::string, GPUCustomMemoryDescriptor> custom_memories_;
+
+ struct ObjectRefArg
+ {
+ GPUObjectDescriptorPtr descriptor;
+ };
+ std::map<std::string, ObjectRefArg> object_refs_;
+
+ struct ObjectArg
+ {
+ GPUObjectPtr obj_ptr;
+ GPUObjectDescriptorPtr descriptor;
+ };
+ std::map<std::string, ObjectArg> objects_;
+};
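+
+// Illustrative usage sketch (comments only, not part of the build). It assumes a
+// DeviceInfo `device_info`, a compiled CLKernel `kernel`, and kernel source `code`
+// that refers to the scalars through the "args." prefix; when GPU objects are added
+// as well, AllocateObjects(context) is also required before Bind. The call order
+// shown here is an illustration, not a normative contract.
+//
+//   Arguments args;
+//   args.AddInt("width");
+//   args.AddFloat("scale");
+//   args.TransformToCLCode(device_info, {}, &code); // rewrites "args.*" references
+//   args.SetInt("width", 128);
+//   args.SetFloat("scale", 0.5f);
+//   args.Bind(kernel.kernel());                     // sets the packed uniforms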
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_ARGUMENTS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Buffer.h"
+
+#include <string>
+
+#include "ClContext.h"
+#include "DataType.h"
+#include "GpuObject.h"
+#include "Util.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, const void *data,
+ CLContext *context, Buffer *result)
+{
+ cl_mem buffer;
+ RETURN_IF_ERROR(CreateCLBuffer(context->context(), size_in_bytes, gpu_read_only,
+ const_cast<void *>(data), &buffer));
+ *result = Buffer(buffer, size_in_bytes);
+
+ return absl::OkStatus();
+}
+
+} // namespace
+
+BufferDescriptor::BufferDescriptor(BufferDescriptor &&desc)
+ : GPUObjectDescriptor(std::move(desc)), element_type(desc.element_type),
+ element_size(desc.element_size), memory_type(desc.memory_type),
+ attributes(std::move(desc.attributes)), size(desc.size), data(std::move(desc.data))
+{
+}
+
+BufferDescriptor &BufferDescriptor::operator=(BufferDescriptor &&desc)
+{
+ if (this != &desc)
+ {
+ std::swap(element_type, desc.element_type);
+ std::swap(element_size, desc.element_size);
+ std::swap(memory_type, desc.memory_type);
+ attributes = std::move(desc.attributes);
+ std::swap(size, desc.size);
+ data = std::move(desc.data);
+ GPUObjectDescriptor::operator=(std::move(desc));
+ }
+ return *this;
+}
+
+void BufferDescriptor::Release() { data.clear(); }
+
+GPUResources BufferDescriptor::GetGPUResources() const
+{
+ GPUResources resources;
+ GPUBufferDescriptor desc;
+ desc.data_type = element_type;
+ desc.access_type = access_type_;
+ desc.element_size = element_size;
+ desc.memory_type = memory_type;
+ desc.attributes = attributes;
+ resources.buffers.push_back({"buffer", desc});
+ return resources;
+}
+
+absl::Status BufferDescriptor::PerformSelector(const std::string &selector,
+ const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const
+{
+ if (selector == "Read")
+ {
+ return PerformReadSelector(args, result);
+ }
+ else if (selector == "GetPtr")
+ {
+ return PerformGetPtrSelector(args, template_args, result);
+ }
+ else
+ {
+ return absl::NotFoundError(
+      absl::StrCat("BufferDescriptor doesn't have a selector with name - ", selector));
+ }
+}
+
+absl::Status BufferDescriptor::PerformReadSelector(const std::vector<std::string> &args,
+ std::string *result) const
+{
+ if (args.size() != 1)
+ {
+ return absl::NotFoundError(
+      absl::StrCat("BufferDescriptor Read requires one argument, but ", args.size(), " were passed"));
+ }
+ *result = absl::StrCat("buffer[", args[0], "]");
+ return absl::OkStatus();
+}
+
+absl::Status BufferDescriptor::PerformGetPtrSelector(const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const
+{
+ if (args.size() > 1)
+ {
+ return absl::NotFoundError(absl::StrCat(
+      "BufferDescriptor GetPtr requires one or zero arguments, but ", args.size(), " were passed"));
+ }
+ if (template_args.size() > 1)
+ {
+    return absl::NotFoundError(absl::StrCat("BufferDescriptor GetPtr requires one or zero template "
+                                            "arguments, but ",
+                                            template_args.size(), " were passed"));
+ }
+ std::string conversion;
+ if (template_args.size() == 1)
+ {
+ const std::string type_name = ToCLDataType(element_type, element_size);
+ if (type_name != template_args[0])
+ {
+ conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ", template_args[0], "*)&");
+ }
+ }
+ if (args.empty())
+ {
+ *result = absl::StrCat(conversion, "buffer");
+ }
+ else if (conversion.empty())
+ {
+ *result = absl::StrCat("(buffer + ", args[0], ")");
+ }
+ else
+ {
+ *result = absl::StrCat(conversion, "buffer[", args[0], "]");
+ }
+ return absl::OkStatus();
+}
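+
+// Illustrative sketch (comments only, not part of the build): the strings the two
+// selectors above generate. Argument names are hypothetical, and the "__global"
+// qualifier assumes MemoryTypeToCLType maps GLOBAL memory to "__global".
+//
+//   Read("id")            -> "buffer[id]"
+//   GetPtr()              -> "buffer"
+//   GetPtr("off")         -> "(buffer + off)"
+//   GetPtr<half4>("off")  -> "(__global half4*)&buffer[off]"  (a cast is added when
+//                            the requested type differs from the element type)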
+
+absl::Status BufferDescriptor::CreateGPUObject(CLContext *context, GPUObjectPtr *result) const
+{
+ Buffer gpu_buffer;
+ RETURN_IF_ERROR(gpu_buffer.CreateFromBufferDescriptor(*this, context));
+ *result = absl::make_unique<Buffer>(std::move(gpu_buffer));
+ return absl::OkStatus();
+}
+
+Buffer::Buffer(cl_mem buffer, size_t size_in_bytes) : buffer_(buffer), size_(size_in_bytes) {}
+
+Buffer::Buffer(Buffer &&buffer) : buffer_(buffer.buffer_), size_(buffer.size_)
+{
+ buffer.buffer_ = nullptr;
+ buffer.size_ = 0;
+}
+
+Buffer &Buffer::operator=(Buffer &&buffer)
+{
+ if (this != &buffer)
+ {
+ Release();
+ std::swap(size_, buffer.size_);
+ std::swap(buffer_, buffer.buffer_);
+ }
+ return *this;
+}
+
+void Buffer::Release()
+{
+ if (buffer_)
+ {
+ clReleaseMemObject(buffer_);
+ buffer_ = nullptr;
+ size_ = 0;
+ }
+}
+
+absl::Status Buffer::GetGPUResources(const GPUObjectDescriptor *obj_ptr,
+ GPUResourcesWithValue *resources) const
+{
+ const auto *buffer_desc = dynamic_cast<const BufferDescriptor *>(obj_ptr);
+ if (!buffer_desc)
+ {
+ return absl::InvalidArgumentError("Expected BufferDescriptor on input.");
+ }
+
+ resources->buffers.push_back({"buffer", buffer_});
+ return absl::OkStatus();
+}
+
+absl::Status Buffer::CreateFromBufferDescriptor(const BufferDescriptor &desc, CLContext *context)
+{
+ bool read_only = desc.memory_type == MemoryType::CONSTANT;
+ uint8_t *data_ptr = desc.data.empty() ? nullptr : const_cast<unsigned char *>(desc.data.data());
+ size_ = desc.size;
+ return CreateCLBuffer(context->context(), desc.size, read_only, data_ptr, &buffer_);
+}
+
+absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext *context, Buffer *result)
+{
+ return CreateBuffer(size_in_bytes, true, nullptr, context, result);
+}
+
+absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void *data, CLContext *context,
+ Buffer *result)
+{
+ return CreateBuffer(size_in_bytes, true, data, context, result);
+}
+
+absl::Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext *context, Buffer *result)
+{
+ return CreateBuffer(size_in_bytes, false, nullptr, context, result);
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_BUFFER_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_BUFFER_H__
+
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+
+#include "ClCommandQueue.h"
+#include "ClContext.h"
+#include "GpuObject.h"
+#include "OpenclWrapper.h"
+#include "DataType.h"
+#include "Util.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+struct BufferDescriptor : public GPUObjectDescriptor
+{
+ DataType element_type;
+ int element_size;
+ MemoryType memory_type = MemoryType::GLOBAL;
+ std::vector<std::string> attributes;
+
+ // optional
+ int size = 0;
+ std::vector<uint8_t> data;
+
+ BufferDescriptor() = default;
+ BufferDescriptor(const BufferDescriptor &) = default;
+ BufferDescriptor &operator=(const BufferDescriptor &) = default;
+ BufferDescriptor(BufferDescriptor &&desc);
+ BufferDescriptor &operator=(BufferDescriptor &&desc);
+
+ absl::Status PerformSelector(const std::string &selector, const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const override;
+
+ GPUResources GetGPUResources() const override;
+ absl::Status PerformReadSelector(const std::vector<std::string> &args, std::string *result) const;
+ absl::Status PerformGetPtrSelector(const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const;
+
+ absl::Status CreateGPUObject(CLContext *context, GPUObjectPtr *result) const override;
+ void Release() override;
+};
+
+// Buffer represents linear GPU data storage with an arbitrary data format.
+// Buffer is movable but not copyable.
+class Buffer : public GPUObject
+{
+public:
+  Buffer() {} // allows Buffer to be used as a class member
+ Buffer(cl_mem buffer, size_t size_in_bytes);
+
+ // Move only
+ Buffer(Buffer &&buffer);
+ Buffer &operator=(Buffer &&buffer);
+ Buffer(const Buffer &) = delete;
+ Buffer &operator=(const Buffer &) = delete;
+
+ virtual ~Buffer() { Release(); }
+
+ // for profiling and memory statistics
+ uint64_t GetMemorySizeInBytes() const { return size_; }
+
+ cl_mem GetMemoryPtr() const { return buffer_; }
+
+  // Writes data to the buffer. data must point to a region whose size in bytes
+  // exactly matches size_in_bytes (the constructor parameter).
+ template <typename T> absl::Status WriteData(CLCommandQueue *queue, const std::vector<T> *data);
+
+ // Reads data from Buffer into CPU memory.
+ template <typename T> absl::Status ReadData(CLCommandQueue *queue, std::vector<T> *result) const;
+
+ absl::Status GetGPUResources(const GPUObjectDescriptor *obj_ptr,
+ GPUResourcesWithValue *resources) const override;
+
+ absl::Status CreateFromBufferDescriptor(const BufferDescriptor &desc, CLContext *context);
+
+private:
+ void Release();
+
+ cl_mem buffer_ = nullptr;
+ size_t size_ = 0;
+};
+
+absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext *context, Buffer *result);
+
+absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void *data, CLContext *context,
+ Buffer *result);
+
+absl::Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext *context, Buffer *result);
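+
+// Illustrative usage sketch (comments only, not part of the build). It assumes a
+// valid CLContext `context` and CLCommandQueue `queue`; error handling is omitted.
+//
+//   std::vector<float> host_data(1024, 1.0f);
+//   Buffer buffer;
+//   CreateReadOnlyBuffer(host_data.size() * sizeof(float), host_data.data(),
+//                        &context, &buffer);
+//   std::vector<float> readback(host_data.size());
+//   buffer.ReadData(&queue, &readback);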
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_BUFFER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClCommandQueue.h"
+
+#include <algorithm>
+#include <map>
+#include <string>
+#include <vector>
+#include <limits>
+
+#include "absl/strings/str_cat.h"
+#include "ClDevice.h"
+#include "ClEvent.h"
+#include "Util.h"
+#include "Types.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+using namespace std;
+
+CLCommandQueue::CLCommandQueue(cl_command_queue queue, bool has_ownership)
+ : queue_(queue), has_ownership_(has_ownership)
+{
+}
+
+CLCommandQueue::CLCommandQueue(CLCommandQueue &&queue)
+ : queue_(queue.queue_), has_ownership_(queue.has_ownership_)
+{
+ queue.queue_ = nullptr;
+}
+
+CLCommandQueue &CLCommandQueue::operator=(CLCommandQueue &&queue)
+{
+ if (this != &queue)
+ {
+ Release();
+ std::swap(queue_, queue.queue_);
+ has_ownership_ = queue.has_ownership_;
+ }
+ return *this;
+}
+
+CLCommandQueue::~CLCommandQueue() { Release(); }
+
+void CLCommandQueue::Release()
+{
+ if (has_ownership_ && queue_)
+ {
+ clReleaseCommandQueue(queue_);
+ queue_ = nullptr;
+ }
+}
+
+absl::Status CLCommandQueue::Dispatch(const CLKernel &kernel, const int3 &work_groups_count,
+ const int3 &work_group_size, CLEvent *event)
+{
+ std::vector<size_t> local(3);
+ std::vector<size_t> global(3);
+ for (int i = 0; i < 3; ++i)
+ {
+ local[i] = work_group_size[i];
+ global[i] = work_groups_count[i] * work_group_size[i];
+ }
+ cl_event resulting_event;
+ const int error_code =
+ clEnqueueNDRangeKernel(queue_, kernel.kernel(), 3, nullptr, global.data(), local.data(), 0,
+ nullptr, event ? &resulting_event : nullptr);
+ if (event)
+ {
+ *event = CLEvent(resulting_event);
+ }
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to clEnqueueNDRangeKernel - ", CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status CLCommandQueue::Dispatch(const CLKernel &kernel, const int3 &work_groups_count,
+ const int3 &work_group_size)
+{
+ return Dispatch(kernel, work_groups_count, work_group_size, nullptr);
+}
+
+absl::Status CLCommandQueue::EnqueueEvent(CLEvent *event)
+{
+ cl_event resulting_event;
+ const int error_code = clEnqueueMarker(queue_, &resulting_event);
+ *event = CLEvent(resulting_event);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to clEnqueueMarker - ", CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region, const void *data)
+{
+ const size_t origin[] = {0, 0, 0};
+ const size_t r[] = {static_cast<size_t>(region.x), static_cast<size_t>(region.y),
+ static_cast<size_t>(region.z)};
+ auto error_code =
+ clEnqueueWriteImage(queue_, memory, CL_TRUE, origin, r, 0, 0, data, 0, nullptr, nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to upload data to GPU (clEnqueueWriteImage) - ",
+ CLErrorCodeToString(error_code)));
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region, void *data)
+{
+ const size_t origin[] = {0, 0, 0};
+ const size_t r[] = {static_cast<size_t>(region.x), static_cast<size_t>(region.y),
+ static_cast<size_t>(region.z)};
+ auto error_code =
+ clEnqueueReadImage(queue_, memory, CL_TRUE, origin, r, 0, 0, data, 0, nullptr, nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to read data from GPU (clEnqueueReadImage) - ",
+ CLErrorCodeToString(error_code)));
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status CLCommandQueue::EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes,
+ const void *data)
+{
+ auto error_code =
+ clEnqueueWriteBuffer(queue_, memory, CL_TRUE, 0, size_in_bytes, data, 0, nullptr, nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to upload data to GPU (clEnqueueWriteBuffer) - ",
+ CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status CLCommandQueue::EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, void *data)
+{
+ auto error_code =
+ clEnqueueReadBuffer(queue_, memory, CL_TRUE, 0, size_in_bytes, data, 0, nullptr, nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to read data from GPU (clEnqueueReadBuffer) - ",
+ CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status CLCommandQueue::WaitForCompletion()
+{
+ auto error_code = clFinish(queue_);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to clFinish - ", CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+ProfilingCommandQueue::ProfilingCommandQueue(cl_command_queue queue) : CLCommandQueue(queue, true)
+{
+ events_.reserve(128);
+}
+
+ProfilingCommandQueue::ProfilingCommandQueue(ProfilingCommandQueue &&queue)
+ : CLCommandQueue(std::move(queue)), events_(std::move(queue.events_)),
+ current_label_(std::move(queue.current_label_))
+{
+}
+
+ProfilingCommandQueue &ProfilingCommandQueue::operator=(ProfilingCommandQueue &&queue)
+{
+ if (this != &queue)
+ {
+ events_ = std::move(queue.events_);
+ current_label_ = std::move(queue.current_label_);
+ CLCommandQueue::operator=(std::move(queue));
+ }
+ return *this;
+}
+
+void ProfilingCommandQueue::SetEventsLabel(const std::string &name) { current_label_ = name; }
+
+void ProfilingCommandQueue::ResetMeasurements() { events_.clear(); }
+
+absl::Status ProfilingCommandQueue::Dispatch(const CLKernel &kernel, const int3 &work_groups_count,
+ const int3 &work_group_size)
+{
+ events_.push_back(CLEvent());
+ RETURN_IF_ERROR(CLCommandQueue::Dispatch(kernel, work_groups_count, work_group_size,
+ &events_[events_.size() - 1]));
+ events_.back().SetName(current_label_);
+ return absl::OkStatus();
+}
+
+absl::Status
+ProfilingCommandQueue::GetBestWorkGroupIndex(const CLKernel &kernel, const DeviceInfo &device_info,
+ const std::vector<int3> &work_groups_count,
+ const std::vector<int3> &work_group_sizes, int *index)
+{
+ // Some Adreno 3xx can have wrong numbers for some events
+ const bool possible_bug_with_events = device_info.IsAdreno3xx();
+ events_.resize(work_group_sizes.size());
+ for (size_t i = 0; i < work_group_sizes.size(); ++i)
+ {
+ RETURN_IF_ERROR(
+ CLCommandQueue::Dispatch(kernel, work_groups_count[i], work_group_sizes[i], &events_[i]));
+
+    // Waiting periodically mitigates a memory leak on Mali for some kernels.
+ if (device_info.IsMali() && i % 8 == 7)
+ {
+ events_[i - 7].Wait();
+ }
+ if (possible_bug_with_events)
+ {
+      // Waiting for completion increases the probability of correct timings.
+ RETURN_IF_ERROR(WaitForCompletion());
+ }
+ }
+
+ RETURN_IF_ERROR(WaitForCompletion());
+
+  // Re-initialize the kernel to release kernel-pool memory on Mali.
+ if (device_info.IsMali())
+ {
+ RETURN_IF_ERROR(kernel.ReInit());
+ }
+
+ int minimum_index = 0;
+ double minimum_time = std::numeric_limits<double>::max();
+ if (possible_bug_with_events)
+  { // try to filter out suspicious results
+ double average_time = 0.0;
+ int average_samples_count = 0;
+ for (size_t i = 0; i < work_group_sizes.size(); ++i)
+ {
+ if (events_[i].GetEventTimeMs() < 100 * 1000)
+ { // 100 sec
+ average_time += events_[i].GetEventTimeMs();
+ average_samples_count++;
+ }
+ }
+ if (average_samples_count == 0)
+ {
+      throw std::runtime_error("Cannot compute average time: no valid profiling events");
+ }
+ else
+ {
+ average_time /= average_samples_count;
+ }
+
+ for (size_t i = 0; i < work_group_sizes.size(); ++i)
+ {
+ double time = events_[i].GetEventTimeMs();
+ if (time < minimum_time && time >= 0.1 * average_time)
+ {
+ minimum_index = i;
+ minimum_time = time;
+ }
+ }
+ }
+ else
+ {
+ for (size_t i = 0; i < work_group_sizes.size(); ++i)
+ {
+ double time = events_[i].GetEventTimeMs();
+ if (time < minimum_time)
+ {
+ minimum_index = i;
+ minimum_time = time;
+ }
+ }
+ }
+
+ *index = minimum_index;
+
+ return absl::OkStatus();
+}
+
+absl::Status CreateCLCommandQueue(const CLDevice &device, const CLContext &context,
+ CLCommandQueue *result)
+{
+ int error_code;
+ cl_command_queue queue = clCreateCommandQueue(context.context(), device.id(), 0, &error_code);
+ if (!queue)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to create a command queue - ", CLErrorCodeToString(error_code)));
+ }
+ *result = CLCommandQueue(queue, true);
+ return absl::OkStatus();
+}
+
+double ProfilingCommandQueue::GetQueueExecutionTimeMs() const
+{
+ const uint64_t start = events_.front().GetStartedTimeNs();
+ const uint64_t end = events_.back().GetFinishedTimeNs();
+ const uint64_t time_ns = (end - start);
+
+ return static_cast<double>(time_ns) / 1000000.0;
+}
+
+double ProfilingCommandQueue::GetSumOfEventsTimeMs() const
+{
+ double sum = 0.0;
+ for (uint32_t i = 0; i < events_.size(); ++i)
+ {
+ sum += events_[i].GetEventTimeMs();
+ }
+ return sum;
+}
+
+absl::Status CreateProfilingCommandQueue(const CLDevice &device, const CLContext &context,
+ ProfilingCommandQueue *result)
+{
+ int error_code;
+ cl_command_queue queue =
+ clCreateCommandQueue(context.context(), device.id(), CL_QUEUE_PROFILING_ENABLE, &error_code);
+ if (!queue)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to create a command queue - ", CLErrorCodeToString(error_code)));
+ }
+
+ *result = ProfilingCommandQueue(queue);
+ return absl::OkStatus();
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_CL_COMMAND_QUEUE_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_CL_COMMAND_QUEUE_H__
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "absl/time/time.h"
+#include "ClContext.h"
+#include "ClDevice.h"
+#include "ClEvent.h"
+#include "ClKernel.h"
+#include "OpenclWrapper.h"
+#include "Types.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+struct ProfilingInfo
+{
+ struct DispatchInfo
+ {
+ std::string label;
+ absl::Duration duration;
+ };
+
+ std::vector<DispatchInfo> dispatches;
+
+ absl::Duration GetTotalTime() const;
+
+  // Returns a report (string of lines delimited by \n).
+  // This method uses GPU counters and measures GPU time only.
+  // The report has the following structure:
+ // Per kernel timing(K kernels):
+ // conv2d 3.2ms
+ // ...
+ // --------------------
+ // Accumulated time per operation type:
+ // conv2d - 14.5ms
+ // ....
+ // --------------------
+ // Ideal total time: 23.4ms // Total time for all kernels
+ std::string GetDetailedReport() const;
+};
+
+// A wrapper around an OpenCL command queue
+class CLCommandQueue
+{
+public:
+ CLCommandQueue() {}
+ CLCommandQueue(cl_command_queue queue, bool has_ownership);
+
+ // Move only
+ CLCommandQueue(CLCommandQueue &&queue);
+ CLCommandQueue &operator=(CLCommandQueue &&queue);
+ CLCommandQueue(const CLCommandQueue &) = delete;
+ CLCommandQueue &operator=(const CLCommandQueue &) = delete;
+
+ virtual ~CLCommandQueue();
+
+ cl_command_queue queue() const { return queue_; }
+
+ virtual absl::Status Dispatch(const CLKernel &kernel, const int3 &work_groups_count,
+ const int3 &work_group_size);
+
+ absl::Status Dispatch(const CLKernel &kernel, const int3 &work_groups_count,
+ const int3 &work_group_size, CLEvent *event);
+
+ absl::Status EnqueueEvent(CLEvent *event);
+
+ absl::Status EnqueueWriteImage(cl_mem memory, int3 region, const void *data);
+ absl::Status EnqueueReadImage(cl_mem memory, int3 region, void *data);
+
+ absl::Status EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes, const void *data);
+ absl::Status EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, void *data);
+
+ absl::Status WaitForCompletion();
+
+protected:
+ void Release();
+
+ cl_command_queue queue_ = nullptr;
+ bool has_ownership_ = false;
+};
+
+class ProfilingCommandQueue : public CLCommandQueue
+{
+public:
+ ProfilingCommandQueue() {}
+ explicit ProfilingCommandQueue(cl_command_queue queue);
+
+ // Move only
+ ProfilingCommandQueue(ProfilingCommandQueue &&queue);
+ ProfilingCommandQueue &operator=(ProfilingCommandQueue &&queue);
+ ProfilingCommandQueue(const ProfilingCommandQueue &) = delete;
+ ProfilingCommandQueue &operator=(const ProfilingCommandQueue &) = delete;
+
+ absl::Status Dispatch(const CLKernel &kernel, const int3 &work_groups_count,
+ const int3 &work_group_size) override;
+
+  // Writes the index of the fastest work group among work_group_sizes into *index.
+ absl::Status GetBestWorkGroupIndex(const CLKernel &kernel, const DeviceInfo &device_info,
+ const std::vector<int3> &work_groups_count,
+ const std::vector<int3> &work_group_sizes, int *index);
+
+  // Call ResetMeasurements() to start a new series of measurements.
+ void ResetMeasurements();
+
+ double GetQueueExecutionTimeMs() const;
+
+  // Unlike GetQueueExecutionTimeMs, this number does not include the time spent
+  // between kernels (kernel launches or preparation) on the GPU. It is usually
+  // 5-10% lower than GetQueueExecutionTimeMs, because that share of the time is
+  // spent on something other than kernel execution.
+ double GetSumOfEventsTimeMs() const;
+
+ // This label will be used for all subsequent dispatches.
+ void SetEventsLabel(const std::string &name);
+
+private:
+ std::vector<CLEvent> events_;
+ std::string current_label_;
+};
+
+absl::Status CreateCLCommandQueue(const CLDevice &device, const CLContext &context,
+ CLCommandQueue *result);
+
+absl::Status CreateProfilingCommandQueue(const CLDevice &device, const CLContext &context,
+ ProfilingCommandQueue *result);
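+
+// Illustrative usage sketch (comments only, not part of the build). It assumes a
+// CLDevice `device`, a CLContext `context`, a compiled CLKernel `kernel`, and
+// int3 work sizes chosen by the caller; error handling is omitted.
+//
+//   ProfilingCommandQueue queue;
+//   CreateProfilingCommandQueue(device, context, &queue);
+//   queue.SetEventsLabel("conv2d");
+//   queue.Dispatch(kernel, work_groups_count, work_group_size);
+//   queue.WaitForCompletion();
+//   double gpu_ms = queue.GetQueueExecutionTimeMs();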
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_CL_COMMAND_QUEUE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClContext.h"
+
+#include "absl/strings/str_cat.h"
+#include "ClImageFormat.h"
+#include "Util.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+std::vector<cl_image_format> GetSupportedImage2DFormats(cl_context context, cl_mem_flags flags)
+{
+ cl_uint num_image_formats;
+ cl_int error = clGetSupportedImageFormats(context, flags, CL_MEM_OBJECT_IMAGE2D, 0, nullptr,
+ &num_image_formats);
+ if (error != CL_SUCCESS)
+ {
+ return {};
+ }
+
+ std::vector<cl_image_format> result(num_image_formats);
+ error = clGetSupportedImageFormats(context, flags, CL_MEM_OBJECT_IMAGE2D, num_image_formats,
+ &result[0], nullptr);
+ if (error != CL_SUCCESS)
+ {
+ return {};
+ }
+ return result;
+}
+
+bool IsEqualToImageFormat(cl_image_format image_format, DataType data_type, int num_channels)
+{
+ return image_format.image_channel_data_type == ToImageChannelType(data_type) &&
+ image_format.image_channel_order == ToChannelOrder(num_channels);
+}
+
+void AddSupportedImageFormats(cl_context context, DeviceInfo *info)
+{
+ auto supported_formats = GetSupportedImage2DFormats(context, CL_MEM_READ_WRITE);
+ for (auto format : supported_formats)
+ {
+ info->supports_r_f16_tex2d =
+ info->supports_r_f16_tex2d || IsEqualToImageFormat(format, DataType::FLOAT16, 1);
+ info->supports_rg_f16_tex2d =
+ info->supports_rg_f16_tex2d || IsEqualToImageFormat(format, DataType::FLOAT16, 2);
+ info->supports_rgb_f16_tex2d =
+ info->supports_rgb_f16_tex2d || IsEqualToImageFormat(format, DataType::FLOAT16, 3);
+ info->supports_rgba_f16_tex2d =
+ info->supports_rgba_f16_tex2d || IsEqualToImageFormat(format, DataType::FLOAT16, 4);
+ info->supports_r_f32_tex2d =
+ info->supports_r_f32_tex2d || IsEqualToImageFormat(format, DataType::FLOAT32, 1);
+ info->supports_rg_f32_tex2d =
+ info->supports_rg_f32_tex2d || IsEqualToImageFormat(format, DataType::FLOAT32, 2);
+ info->supports_rgb_f32_tex2d =
+ info->supports_rgb_f32_tex2d || IsEqualToImageFormat(format, DataType::FLOAT32, 3);
+ info->supports_rgba_f32_tex2d =
+ info->supports_rgba_f32_tex2d || IsEqualToImageFormat(format, DataType::FLOAT32, 4);
+ }
+}
+
+absl::Status CreateCLContext(const CLDevice &device, cl_context_properties *properties,
+ CLContext *result)
+{
+ int error_code;
+ cl_device_id device_id = device.id();
+ cl_context context = clCreateContext(properties, 1, &device_id, nullptr, nullptr, &error_code);
+ if (!context)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to create a compute context - ", CLErrorCodeToString(error_code)));
+ }
+ AddSupportedImageFormats(context, &device.info_);
+
+ *result = CLContext(context, true);
+ return absl::OkStatus();
+}
+
+} // namespace
+
+CLContext::CLContext(cl_context context, bool has_ownership)
+ : context_(context), has_ownership_(has_ownership)
+{
+}
+
+CLContext::CLContext(CLContext &&context)
+ : context_(context.context_), has_ownership_(context.has_ownership_)
+{
+ context.context_ = nullptr;
+}
+
+CLContext &CLContext::operator=(CLContext &&context)
+{
+ if (this != &context)
+ {
+ Release();
+ std::swap(context_, context.context_);
+ has_ownership_ = context.has_ownership_;
+ }
+ return *this;
+}
+
+CLContext::~CLContext() { Release(); }
+
+void CLContext::Release()
+{
+ if (has_ownership_ && context_)
+ {
+ clReleaseContext(context_);
+ context_ = nullptr;
+ }
+}
+
+bool CLContext::IsFloatTexture2DSupported(int num_channels, DataType data_type,
+ cl_mem_flags flags) const
+{
+ auto supported_formats = GetSupportedImage2DFormats(context_, flags);
+ for (auto format : supported_formats)
+ {
+ if (format.image_channel_data_type == ToImageChannelType(data_type) &&
+ format.image_channel_order == ToChannelOrder(num_channels))
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
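+
+// Illustrative sketch (comments only, not part of the build): a typical capability
+// check before choosing a texture-backed storage layout. Variable names are
+// hypothetical.
+//
+//   const bool rgba_f16_ok = context.IsFloatTexture2DSupported(4, DataType::FLOAT16);
+//   // fall back to buffer storage when the format is not supported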
+
+absl::Status CreateCLContext(const CLDevice &device, CLContext *result)
+{
+ return CreateCLContext(device, nullptr, result);
+}
+
+absl::Status CreateCLGLContext(const CLDevice &device, cl_context_properties egl_context,
+ cl_context_properties egl_display, CLContext *result)
+{
+ if (!device.SupportsExtension("cl_khr_gl_sharing"))
+ {
+ return absl::UnavailableError("Device doesn't support CL-GL sharing.");
+ }
+ cl_context_properties platform = reinterpret_cast<cl_context_properties>(device.platform());
+ cl_context_properties props[] = {CL_GL_CONTEXT_KHR,
+ egl_context,
+ CL_EGL_DISPLAY_KHR,
+ egl_display,
+ CL_CONTEXT_PLATFORM,
+ platform,
+ 0};
+ return CreateCLContext(device, props, result);
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_CL_CONTEXT_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_CL_CONTEXT_H__
+
+#include "ClDevice.h"
+#include "OpenclWrapper.h"
+#include "DataType.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// An RAII wrapper around an OpenCL context
+class CLContext
+{
+public:
+ CLContext() {}
+ CLContext(cl_context context, bool has_ownership);
+
+ // Move only
+ CLContext(CLContext &&context);
+ CLContext &operator=(CLContext &&context);
+ CLContext(const CLContext &) = delete;
+ CLContext &operator=(const CLContext &) = delete;
+
+ ~CLContext();
+
+ cl_context context() const { return context_; }
+
+ bool IsFloatTexture2DSupported(int num_channels, DataType data_type,
+ cl_mem_flags flags = CL_MEM_READ_WRITE) const;
+
+private:
+ void Release();
+
+ cl_context context_ = nullptr;
+ bool has_ownership_ = false;
+};
+
+absl::Status CreateCLContext(const CLDevice &device, CLContext *result);
+absl::Status CreateCLGLContext(const CLDevice &device, cl_context_properties egl_context,
+ cl_context_properties egl_display, CLContext *result);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_CL_CONTEXT_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClDevice.h"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "Util.h"
+#include "Status.h"
+
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+template <> std::string GetDeviceInfo<std::string>(cl_device_id id, cl_device_info info)
+{
+ size_t size;
+ cl_int error = clGetDeviceInfo(id, info, 0, nullptr, &size);
+ if (error != CL_SUCCESS)
+ {
+ return "";
+ }
+
+ std::string result(size - 1, 0);
+ error = clGetDeviceInfo(id, info, size, &result[0], nullptr);
+ if (error != CL_SUCCESS)
+ {
+ return "";
+ }
+ return result;
+}
+
+namespace
+{
+template <typename T> T GetPlatformInfo(cl_platform_id id, cl_platform_info info)
+{
+ T result;
+ cl_int error = clGetPlatformInfo(id, info, sizeof(T), &result, nullptr);
+ if (error != CL_SUCCESS)
+ {
+ return -1;
+ }
+ return result;
+}
+
+std::string GetPlatformInfo(cl_platform_id id, cl_platform_info info)
+{
+ size_t size;
+ cl_int error = clGetPlatformInfo(id, info, 0, nullptr, &size);
+ if (error != CL_SUCCESS)
+ {
+ return "";
+ }
+
+ std::string result(size - 1, 0);
+ error = clGetPlatformInfo(id, info, size, &result[0], nullptr);
+ if (error != CL_SUCCESS)
+ {
+ return "";
+ }
+ return result;
+}
+
+void GetDeviceWorkDimsSizes(cl_device_id id, int3 *result)
+{
+ int dims_count = GetDeviceInfo<cl_uint>(id, CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS);
+ if (dims_count < 3)
+ {
+ return;
+ }
+ std::vector<size_t> limits(dims_count);
+ cl_int error = clGetDeviceInfo(id, CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(size_t) * dims_count,
+ limits.data(), nullptr);
+ if (error != CL_SUCCESS)
+ {
+ return;
+ }
+ // dims_count must be at least 3 according to spec
+ result->x = limits[0];
+ result->y = limits[1];
+ result->z = limits[2];
+}
+
+OpenCLVersion ParseCLVersion(const std::string &version)
+{
+ const auto first_dot_pos = version.find_first_of('.');
+ if (first_dot_pos == std::string::npos)
+ {
+ return OpenCLVersion::CL_1_0;
+ }
+ const int major = version[first_dot_pos - 1] - '0';
+ const int minor = version[first_dot_pos + 1] - '0';
+
+ if (major == 1)
+ {
+ if (minor == 2)
+ {
+ return OpenCLVersion::CL_1_2;
+ }
+ else if (minor == 1)
+ {
+ return OpenCLVersion::CL_1_1;
+ }
+ else
+ {
+ return OpenCLVersion::CL_1_0;
+ }
+ }
+ else if (major == 2)
+ {
+ if (minor == 2)
+ {
+ return OpenCLVersion::CL_2_2;
+ }
+ else if (minor == 1)
+ {
+ return OpenCLVersion::CL_2_1;
+ }
+ else
+ {
+ return OpenCLVersion::CL_2_0;
+ }
+ }
+ else if (major == 3)
+ {
+ return OpenCLVersion::CL_3_0;
+ }
+ else
+ {
+ return OpenCLVersion::CL_1_0;
+ }
+}
+
+Vendor ParseVendor(const std::string &device_name, const std::string &vendor_name)
+{
+ std::string d_name = device_name;
+ std::string v_name = vendor_name;
+ std::transform(d_name.begin(), d_name.end(), d_name.begin(), ::tolower);
+ std::transform(v_name.begin(), v_name.end(), v_name.begin(), ::tolower);
+ if (d_name.find("qualcomm") != std::string::npos || v_name.find("qualcomm") != std::string::npos)
+ {
+ return Vendor::kQualcomm;
+ }
+ else if (d_name.find("mali") != std::string::npos || v_name.find("mali") != std::string::npos)
+ {
+ return Vendor::kMali;
+ }
+ else if (d_name.find("power") != std::string::npos || v_name.find("power") != std::string::npos)
+ {
+ return Vendor::kPowerVR;
+ }
+ else if (d_name.find("nvidia") != std::string::npos || v_name.find("nvidia") != std::string::npos)
+ {
+ return Vendor::kNvidia;
+ }
+ else if (d_name.find("advanced micro devices") != std::string::npos ||
+ v_name.find("advanced micro devices") != std::string::npos)
+ {
+ return Vendor::kAMD;
+ }
+ else if (d_name.find("intel") != std::string::npos || v_name.find("intel") != std::string::npos)
+ {
+ return Vendor::kIntel;
+ }
+ else
+ {
+ return Vendor::kUnknown;
+ }
+}
+
+// Checks that gpu_version belongs to the range [min_version, max_version):
+// min_version is included and max_version is excluded.
+bool IsGPUVersionInRange(int gpu_version, int min_version, int max_version)
+{
+ return gpu_version >= min_version && gpu_version < max_version;
+}
+} // namespace
+
+DeviceInfo DeviceInfoFromDeviceID(cl_device_id id)
+{
+ DeviceInfo info;
+ const auto device_name = GetDeviceInfo<std::string>(id, CL_DEVICE_NAME);
+ const auto vendor_name = GetDeviceInfo<std::string>(id, CL_DEVICE_VENDOR);
+ const auto opencl_c_version = GetDeviceInfo<std::string>(id, CL_DEVICE_OPENCL_C_VERSION);
+ info.vendor = ParseVendor(device_name, vendor_name);
+ if (info.vendor == Vendor::kQualcomm)
+ {
+ info.adreno_info = AdrenoInfo(opencl_c_version);
+ }
+ else if (info.vendor == Vendor::kMali)
+ {
+ info.mali_info = MaliInfo(device_name);
+ }
+ info.cl_version = ParseCLVersion(opencl_c_version);
+ info.extensions = absl::StrSplit(GetDeviceInfo<std::string>(id, CL_DEVICE_EXTENSIONS), ' ');
+
+ info.supports_fp16 = false;
+ info.supports_image3d_writes = false;
+ for (const auto &ext : info.extensions)
+ {
+ if (ext == "cl_khr_fp16")
+ {
+ info.supports_fp16 = true;
+ }
+ if (ext == "cl_khr_3d_image_writes")
+ {
+ info.supports_image3d_writes = true;
+ }
+ }
+
+ cl_device_fp_config f32_config =
+ GetDeviceInfo<cl_device_fp_config>(id, CL_DEVICE_SINGLE_FP_CONFIG);
+ info.supports_fp32_rtn = f32_config & CL_FP_ROUND_TO_NEAREST;
+
+ if (info.supports_fp16)
+ {
+ cl_device_fp_config f16_config;
+ auto status = GetDeviceInfo<cl_device_fp_config>(id, CL_DEVICE_HALF_FP_CONFIG, &f16_config);
+ // AMD supports cl_khr_fp16 but CL_DEVICE_HALF_FP_CONFIG is empty.
+ if (status.ok() && info.vendor != Vendor::kAMD)
+ {
+ info.supports_fp16_rtn = f16_config & CL_FP_ROUND_TO_NEAREST;
+ }
+ else
+ { // happens on PowerVR
+ f16_config = f32_config;
+ info.supports_fp16_rtn = info.supports_fp32_rtn;
+ }
+ }
+ else
+ {
+ info.supports_fp16_rtn = false;
+ }
+
+ if (info.vendor == Vendor::kPowerVR && !info.supports_fp16)
+ {
+    // PowerVR doesn't fully support fp16 and therefore doesn't list this
+    // extension, but it supports fp16 in MADs and as buffer/texture types,
+    // so we use it anyway.
+ info.supports_fp16 = true;
+ info.supports_fp16_rtn = info.supports_fp32_rtn;
+ }
+
+ if (!info.supports_image3d_writes &&
+ ((info.vendor == Vendor::kQualcomm &&
+ IsGPUVersionInRange(info.adreno_info.gpu_version, 400, 500)) ||
+ info.vendor == Vendor::kNvidia))
+ {
+    // In local tests, Adreno 430 can write to 3D images, at least for small
+    // sizes, even though cl_khr_3d_image_writes is not in its list of
+    // available extensions. The same holds for NVIDIA.
+ info.supports_image3d_writes = true;
+ }
+ info.compute_units_count = GetDeviceInfo<cl_uint>(id, CL_DEVICE_MAX_COMPUTE_UNITS);
+ info.image2d_max_width = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_WIDTH);
+ info.image2d_max_height = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_HEIGHT);
+ info.buffer_max_size = GetDeviceInfo<cl_ulong>(id, CL_DEVICE_MAX_MEM_ALLOC_SIZE);
+ if (info.cl_version >= OpenCLVersion::CL_1_2)
+ {
+ info.image_buffer_max_size = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE);
+ info.image_array_max_layers = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE_MAX_ARRAY_SIZE);
+ }
+ info.image3d_max_width = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_WIDTH);
+  info.image3d_max_height = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_HEIGHT);
+ info.image3d_max_depth = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_DEPTH);
+ int3 max_work_group_sizes;
+ GetDeviceWorkDimsSizes(id, &max_work_group_sizes);
+ info.max_work_group_size_x = max_work_group_sizes.x;
+ info.max_work_group_size_y = max_work_group_sizes.y;
+ info.max_work_group_size_z = max_work_group_sizes.z;
+
+ if (info.IsIntel())
+ {
+ if (info.SupportsExtension("cl_intel_required_subgroup_size"))
+ {
+      size_t sub_groups_ret_size;
+      cl_int status = clGetDeviceInfo(id, 0x4108 /*CL_DEVICE_SUB_GROUP_SIZES_INTEL*/, 0, nullptr,
+                                      &sub_groups_ret_size);
+      if (status == CL_SUCCESS)
+      {
+        // The driver reports the size in bytes, not the number of entries.
+        const size_t sub_groups_count = sub_groups_ret_size / sizeof(size_t);
+        std::vector<size_t> sub_group_sizes(sub_groups_count);
+        status = clGetDeviceInfo(id, 0x4108 /*CL_DEVICE_SUB_GROUP_SIZES_INTEL*/,
+                                 sub_groups_ret_size, sub_group_sizes.data(), nullptr);
+ if (status == CL_SUCCESS)
+ {
+ for (size_t i = 0; i < sub_groups_count; ++i)
+ {
+ info.supported_subgroup_sizes.push_back(sub_group_sizes[i]);
+ }
+ }
+ }
+ }
+ }
+ return info;
+}
+
+CLDevice::CLDevice(cl_device_id id, cl_platform_id platform_id)
+ : info_(DeviceInfoFromDeviceID(id)), id_(id), platform_id_(platform_id)
+{
+}
+
+CLDevice::CLDevice(const CLDevice &device)
+ : info_(device.info_), id_(device.id_), platform_id_(device.platform_id_)
+{
+}
+
+CLDevice &CLDevice::operator=(const CLDevice &device)
+{
+ if (this != &device)
+ {
+ info_ = device.info_;
+ id_ = device.id_;
+ platform_id_ = device.platform_id_;
+ }
+ return *this;
+}
+
+CLDevice::CLDevice(CLDevice &&device)
+ : info_(std::move(device.info_)), id_(device.id_), platform_id_(device.platform_id_)
+{
+ device.id_ = nullptr;
+ device.platform_id_ = nullptr;
+}
+
+CLDevice &CLDevice::operator=(CLDevice &&device)
+{
+ if (this != &device)
+ {
+ id_ = nullptr;
+ platform_id_ = nullptr;
+ info_ = std::move(device.info_);
+ std::swap(id_, device.id_);
+ std::swap(platform_id_, device.platform_id_);
+ }
+ return *this;
+}
+
+bool CLDevice::SupportsFP16() const { return info_.supports_fp16; }
+
+bool CLDevice::SupportsExtension(const std::string &extension) const
+{
+ return info_.SupportsExtension(extension);
+}
+
+bool CLDevice::SupportsTextureArray() const { return info_.SupportsTextureArray(); }
+
+bool CLDevice::SupportsImageBuffer() const { return info_.SupportsImageBuffer(); }
+
+bool CLDevice::SupportsImage3D() const { return info_.SupportsImage3D(); }
+
+bool CLDevice::SupportsFP32RTN() const { return info_.supports_fp32_rtn; }
+
+bool CLDevice::SupportsFP16RTN() const { return info_.supports_fp16_rtn; }
+
+std::string CLDevice::GetPlatformVersion() const
+{
+ return GetPlatformInfo(platform_id_, CL_PLATFORM_VERSION);
+}
+
+bool CLDevice::IsCL20OrHigher() const { return info_.IsCL20OrHigher(); }
+
+bool CLDevice::SupportsSubGroupWithSize(int sub_group_size) const
+{
+ return info_.SupportsSubGroupWithSize(sub_group_size);
+}
+
+bool CLDevice::IsAdreno() const { return info_.IsAdreno(); }
+
+bool CLDevice::IsAdreno3xx() const { return info_.IsAdreno3xx(); }
+
+bool CLDevice::IsAdreno4xx() const { return info_.IsAdreno4xx(); }
+
+bool CLDevice::IsAdreno5xx() const { return info_.IsAdreno5xx(); }
+
+bool CLDevice::IsAdreno6xx() const { return info_.IsAdreno6xx(); }
+
+bool CLDevice::IsAdreno6xxOrHigher() const { return info_.IsAdreno6xxOrHigher(); }
+
+bool CLDevice::IsPowerVR() const { return info_.IsPowerVR(); }
+
+bool CLDevice::IsNvidia() const { return info_.IsNvidia(); }
+
+bool CLDevice::IsMali() const { return info_.IsMali(); }
+
+bool CLDevice::IsAMD() const { return info_.IsAMD(); }
+
+bool CLDevice::IsIntel() const { return info_.IsIntel(); }
+
+bool CLDevice::SupportsOneLayerTextureArray() const { return info_.SupportsOneLayerTextureArray(); }
+
+void CLDevice::DisableOneLayerTextureArray()
+{
+ info_.adreno_info.support_one_layer_texture_array = false;
+}
+
+absl::Status CreateDefaultGPUDevice(CLDevice *result)
+{
+ cl_uint num_platforms;
+ clGetPlatformIDs(0, nullptr, &num_platforms);
+ if (num_platforms == 0)
+ {
+ return absl::UnknownError("No supported OpenCL platform.");
+ }
+ std::vector<cl_platform_id> platforms(num_platforms);
+ clGetPlatformIDs(num_platforms, platforms.data(), nullptr);
+
+ cl_platform_id platform_id = platforms[0];
+ cl_uint num_devices;
+ clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices);
+ if (num_devices == 0)
+ {
+ return absl::UnknownError("No GPU on current platform.");
+ }
+
+ std::vector<cl_device_id> devices(num_devices);
+ clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, num_devices, devices.data(), nullptr);
+
+ *result = CLDevice(devices[0], platform_id);
+ return absl::OkStatus();
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_CL_DEVICE_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_CL_DEVICE_H__
+
+#include <string>
+#include <vector>
+
+#include "DeviceInfo.h"
+#include "OpenclWrapper.h"
+#include "Util.h"
+#include "Types.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// A wrapper around an OpenCL device id
+class CLDevice
+{
+public:
+ CLDevice() = default;
+ CLDevice(cl_device_id id, cl_platform_id platform_id);
+
+ CLDevice(CLDevice &&device);
+ CLDevice &operator=(CLDevice &&device);
+ CLDevice(const CLDevice &);
+ CLDevice &operator=(const CLDevice &);
+
+ ~CLDevice() {}
+
+ cl_device_id id() const { return id_; }
+ cl_platform_id platform() const { return platform_id_; }
+ std::string GetPlatformVersion() const;
+
+ Vendor vendor() const { return info_.vendor; }
+ OpenCLVersion cl_version() const { return info_.cl_version; }
+ bool SupportsFP16() const;
+ bool SupportsTextureArray() const;
+ bool SupportsImageBuffer() const;
+ bool SupportsImage3D() const;
+ bool SupportsExtension(const std::string &extension) const;
+ bool SupportsFP32RTN() const;
+ bool SupportsFP16RTN() const;
+ bool IsCL20OrHigher() const;
+ bool SupportsSubGroupWithSize(int sub_group_size) const;
+ bool IsAdreno() const;
+ bool IsAdreno3xx() const;
+ bool IsAdreno4xx() const;
+ bool IsAdreno5xx() const;
+ bool IsAdreno6xx() const;
+ bool IsAdreno6xxOrHigher() const;
+ bool IsPowerVR() const;
+ bool IsNvidia() const;
+ bool IsMali() const;
+ bool IsAMD() const;
+ bool IsIntel() const;
+
+ // To track bug on some Adreno. b/131099086
+ bool SupportsOneLayerTextureArray() const;
+ void DisableOneLayerTextureArray();
+
+ const DeviceInfo &GetInfo() const { return info_; }
+ // We update the device info during context creation because supported
+ // texture formats can only be queried from a context.
+ mutable DeviceInfo info_;
+
+private:
+ cl_device_id id_ = nullptr;
+ cl_platform_id platform_id_ = nullptr;
+};
+
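+// Illustrative usage (a sketch, not a guaranteed pattern):
+//   CLDevice device;
+//   RETURN_IF_ERROR(CreateDefaultGPUDevice(&device));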
+absl::Status CreateDefaultGPUDevice(CLDevice *result);
+
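+// Illustrative example for the helpers below (assumes the standard OpenCL
+// parameter name): querying the number of compute units of a device:
+//   cl_uint compute_units = GetDeviceInfo<cl_uint>(id, CL_DEVICE_MAX_COMPUTE_UNITS);
+// On failure the value-returning overload yields -1 cast to T.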
+template <typename T> T GetDeviceInfo(cl_device_id id, cl_device_info info)
+{
+ T result;
+ cl_int error = clGetDeviceInfo(id, info, sizeof(T), &result, nullptr);
+ if (error != CL_SUCCESS)
+ {
+ return -1;
+ }
+ return result;
+}
+
+template <typename T> absl::Status GetDeviceInfo(cl_device_id id, cl_device_info info, T *result)
+{
+ cl_int error = clGetDeviceInfo(id, info, sizeof(T), result, nullptr);
+ if (error != CL_SUCCESS)
+ {
+ return absl::InvalidArgumentError(CLErrorCodeToString(error));
+ }
+ return absl::OkStatus();
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_CL_DEVICE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_CL_ERRORS_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_CL_ERRORS_H__
+
+#include <string>
+
+#include "Util.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// @return OK status if error_code is CL_SUCCESS; otherwise a status that
+// translates the error code into a message.
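+//
+// Illustrative usage (sketch, assuming a valid command queue `queue` inside a
+// function that returns absl::Status):
+//   RETURN_IF_ERROR(GetOpenCLError(clFinish(queue)));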
+inline absl::Status GetOpenCLError(cl_int error_code)
+{
+ if (error_code == CL_SUCCESS)
+ {
+ return absl::OkStatus();
+ }
+ return absl::InternalError("OpenCL error: " + CLErrorCodeToString(error_code));
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_CL_ERRORS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClEvent.h"
+
+#include "OpenclWrapper.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+CLEvent::CLEvent(cl_event event) : event_(event) {}
+
+CLEvent::CLEvent(CLEvent &&event) : event_(event.event_), name_(std::move(event.name_))
+{
+ event.event_ = nullptr;
+}
+
+CLEvent &CLEvent::operator=(CLEvent &&event)
+{
+ if (this != &event)
+ {
+ Release();
+ std::swap(event_, event.event_);
+ name_ = std::move(event.name_);
+ }
+ return *this;
+}
+
+uint64_t CLEvent::GetStartedTimeNs() const
+{
+ cl_ulong time_ns;
+ clGetEventProfilingInfo(event_, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &time_ns, nullptr);
+ return time_ns;
+}
+
+uint64_t CLEvent::GetFinishedTimeNs() const
+{
+ cl_ulong time_ns;
+ clGetEventProfilingInfo(event_, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &time_ns, nullptr);
+ return time_ns;
+}
+
+double CLEvent::GetEventTimeMs() const
+{
+ const uint64_t start = GetStartedTimeNs();
+ const uint64_t end = GetFinishedTimeNs();
+ const uint64_t time_ns = (end - start);
+
+ return static_cast<double>(time_ns) * 1e-6;
+}
+
+uint64_t CLEvent::GetEventTimeNs() const { return GetFinishedTimeNs() - GetStartedTimeNs(); }
+
+void CLEvent::SetName(const std::string &name) { name_ = name; }
+
+void CLEvent::Wait() const { clWaitForEvents(1, &event_); }
+
+CLEvent::~CLEvent() { Release(); }
+
+void CLEvent::Release()
+{
+ if (event_)
+ {
+ clReleaseEvent(event_);
+ event_ = nullptr;
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_CL_EVENT_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_CL_EVENT_H__
+
+#include <cstdint>
+#include <string>
+
+#include "OpenclWrapper.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// An RAII wrapper around an OpenCL event
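+//
+// Illustrative usage (assumes the queue was created with
+// CL_QUEUE_PROFILING_ENABLE and `raw_event` came from an enqueue call):
+//   CLEvent event(raw_event);
+//   event.Wait();
+//   double elapsed_ms = event.GetEventTimeMs();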
+class CLEvent
+{
+public:
+ CLEvent() {}
+ explicit CLEvent(cl_event event);
+
+ // Move only
+ CLEvent(CLEvent &&event);
+ CLEvent &operator=(CLEvent &&event);
+ CLEvent(const CLEvent &) = delete;
+ CLEvent &operator=(const CLEvent &) = delete;
+
+ ~CLEvent();
+
+ uint64_t GetStartedTimeNs() const;
+ uint64_t GetFinishedTimeNs() const;
+
+ double GetEventTimeMs() const;
+ uint64_t GetEventTimeNs() const;
+
+ void Wait() const;
+
+ cl_event event() const { return event_; }
+
+ bool is_valid() const { return event_ != nullptr; }
+
+ void SetName(const std::string &name);
+ std::string GetName() const { return name_; }
+
+private:
+ void Release();
+
+ cl_event event_ = nullptr;
+
+ std::string name_; // optional, for profiling mostly
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_CL_EVENT_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClImageFormat.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+cl_channel_order ToChannelOrder(int num_channels)
+{
+ switch (num_channels)
+ {
+ case 1:
+ return CL_R;
+ case 2:
+ return CL_RG;
+ case 3:
+ return CL_RGB;
+ case 4:
+ return CL_RGBA;
+ default:
+ return -1;
+ }
+}
+
+cl_channel_type ToImageChannelType(DataType data_type)
+{
+ switch (data_type)
+ {
+ case DataType::FLOAT32:
+ return CL_FLOAT;
+ case DataType::FLOAT16:
+ return CL_HALF_FLOAT;
+ default:
+ return -1;
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_CL_IMAGE_FORMAT_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_CL_IMAGE_FORMAT_H__
+
+#include "OpenclWrapper.h"
+#include "DataType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+cl_channel_order ToChannelOrder(int num_channels);
+
+cl_channel_type ToImageChannelType(DataType data_type);
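+
+// Illustrative example: an RGBA FLOAT16 image descriptor would be filled as
+//   cl_image_format format;
+//   format.image_channel_order = ToChannelOrder(4);                          // CL_RGBA
+//   format.image_channel_data_type = ToImageChannelType(DataType::FLOAT16);  // CL_HALF_FLOAT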
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_CL_IMAGE_FORMAT_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClKernel.h"
+
+#include "absl/strings/str_cat.h"
+#include "ClProgram.h"
+#include "Util.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+absl::Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id, int *result)
+{
+ size_t max_work_group_size;
+ cl_int error_code = clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE,
+ sizeof(size_t), &max_work_group_size, nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to get info CL_KERNEL_WORK_GROUP_SIZE ",
+ CLErrorCodeToString(error_code)));
+ }
+ *result = static_cast<int>(max_work_group_size);
+ return absl::OkStatus();
+}
+
+absl::Status GetKernelPrivateMemorySize(cl_kernel kernel, cl_device_id device_id, int *result)
+{
+ cl_ulong private_mem_size;
+ cl_int error_code = clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_PRIVATE_MEM_SIZE,
+ sizeof(cl_ulong), &private_mem_size, nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to get info CL_KERNEL_PRIVATE_MEM_SIZE ",
+ CLErrorCodeToString(error_code)));
+ }
+ *result = static_cast<int>(private_mem_size);
+ return absl::OkStatus();
+}
+
+} // namespace
+
+CLKernel::CLKernel(CLKernel &&kernel)
+ : info_(kernel.info_), binding_counter_(kernel.binding_counter_),
+ function_name_(std::move(kernel.function_name_)), program_(kernel.program_),
+ kernel_(kernel.kernel_)
+{
+ kernel.kernel_ = nullptr;
+}
+
+CLKernel &CLKernel::operator=(CLKernel &&kernel)
+{
+ if (this != &kernel)
+ {
+ Release();
+ std::swap(info_, kernel.info_);
+ std::swap(binding_counter_, kernel.binding_counter_);
+ function_name_ = std::move(kernel.function_name_);
+ std::swap(program_, kernel.program_);
+ std::swap(kernel_, kernel.kernel_);
+ }
+ return *this;
+}
+
+CLKernel::~CLKernel() { Release(); }
+
+absl::Status CLKernel::ReInit() const
+{
+ clReleaseKernel(kernel_);
+ cl_kernel *kern_ptr = const_cast<cl_kernel *>(&kernel_);
+ int error_code;
+ *kern_ptr = clCreateKernel(program_, function_name_.c_str(), &error_code);
+ if (!kernel_ || error_code != CL_SUCCESS)
+ {
+ *kern_ptr = nullptr;
+ return absl::UnknownError(
+ absl::StrCat("Failed to create ", function_name_, CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+void CLKernel::Release()
+{
+ if (kernel_)
+ {
+ clReleaseKernel(kernel_);
+ clReleaseProgram(program_);
+ kernel_ = nullptr;
+ }
+}
+
+absl::Status CLKernel::CreateFromProgram(const CLProgram &program, const std::string &function_name)
+{
+ int error_code;
+ function_name_ = function_name;
+ kernel_ = clCreateKernel(program.program(), function_name.c_str(), &error_code);
+ if (!kernel_ || error_code != CL_SUCCESS)
+ {
+ kernel_ = nullptr;
+ return absl::UnknownError(
+ absl::StrCat("Failed to create ", function_name, CLErrorCodeToString(error_code)));
+ }
+
+ program_ = program.program();
+ clRetainProgram(program_);
+
+ RETURN_IF_ERROR(
+ GetKernelPrivateMemorySize(kernel_, program.GetDeviceId(), &info_.private_memory_size));
+ RETURN_IF_ERROR(
+ GetKernelMaxWorkGroupSize(kernel_, program.GetDeviceId(), &info_.max_work_group_size));
+ return absl::OkStatus();
+}
+
+absl::Status CLKernel::SetMemory(int index, cl_mem memory)
+{
+ return SetBytes(index, &memory, sizeof(cl_mem));
+}
+
+absl::Status CLKernel::SetMemoryAuto(cl_mem memory)
+{
+ return SetBytesAuto(&memory, sizeof(cl_mem));
+}
+
+absl::Status CLKernel::SetBytes(int index, const void *ptr, int length) const
+{
+ const int error_code = clSetKernelArg(kernel_, index, length, ptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to set kernel arguments - ", CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status CLKernel::SetBytesAuto(const void *ptr, int length)
+{
+ const int error_code = clSetKernelArg(kernel_, binding_counter_, length, ptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+ CLErrorCodeToString(error_code), "(at index - ",
+ binding_counter_, ")"));
+ }
+ binding_counter_++;
+ return absl::OkStatus();
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_CL_KERNEL_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_CL_KERNEL_H__
+
+#include <string>
+
+#include "ClContext.h"
+#include "ClDevice.h"
+#include "ClProgram.h"
+#include "OpenclWrapper.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+struct KernelInfo
+{
+ int private_memory_size = 0;
+ int max_work_group_size = 0;
+};
+
+// Argument binding to a CLKernel can be manual or automatic.
+// With manual binding you specify the argument index explicitly.
+// With automatic binding the index is auto-incremented on every binding call.
+// If you use automatic mode, you must call ResetBindingCounter before binding
+// the parameters.
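+//
+// Illustrative sketch of automatic binding (src_buffer, dst_buffer and scale
+// are hypothetical):
+//   kernel.ResetBindingCounter();
+//   kernel.SetMemoryAuto(src_buffer);  // argument 0
+//   kernel.SetMemoryAuto(dst_buffer);  // argument 1
+//   kernel.SetBytesAuto(scale);        // argument 2
+// Manual binding passes the index explicitly, e.g. kernel.SetMemory(0, src_buffer).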
+class CLKernel
+{
+public:
+ CLKernel() {}
+
+ // Move only
+ CLKernel(CLKernel &&kernel);
+ CLKernel &operator=(CLKernel &&kernel);
+ CLKernel(const CLKernel &) = delete;
+ CLKernel &operator=(const CLKernel &) = delete;
+
+ ~CLKernel();
+
+ cl_kernel kernel() const { return kernel_; }
+
+ absl::Status CreateFromProgram(const CLProgram &program, const std::string &function_name);
+
+ absl::Status SetMemory(int index, cl_mem memory);
+ absl::Status SetMemoryAuto(cl_mem memory);
+ template <typename T> absl::Status SetBytes(int index, const T &value) const
+ {
+ return SetBytes(index, static_cast<const void *>(&value), sizeof(T));
+ }
+ template <typename T> absl::Status SetBytesAuto(const T &value)
+ {
+ return SetBytesAuto(static_cast<const void *>(&value), sizeof(T));
+ }
+
+ int GetBindingCounter() const { return binding_counter_; }
+ void ResetBindingCounter() { binding_counter_ = 0; }
+
+ // Do not use this function directly; it exists only as a workaround for a
+ // Mali driver memory leak.
+ absl::Status ReInit() const;
+
+ KernelInfo info_;
+
+private:
+ void Release();
+ absl::Status SetBytes(int index, const void *ptr, int length) const;
+ absl::Status SetBytesAuto(const void *ptr, int length);
+
+ int binding_counter_ = -1;
+
+ std::string function_name_ = "";
+ // reference to program from which kernel was created
+ cl_program program_ = nullptr;
+ cl_kernel kernel_ = nullptr;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_CL_KERNEL_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClMemory.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+cl_mem_flags ToClMemFlags(AccessType access_type)
+{
+ switch (access_type)
+ {
+ case AccessType::READ:
+ return CL_MEM_READ_ONLY;
+ case AccessType::WRITE:
+ return CL_MEM_WRITE_ONLY;
+ case AccessType::READ_WRITE:
+ return CL_MEM_READ_WRITE;
+ default:
+ throw std::runtime_error("Invalid AccessType");
+ }
+
+ return CL_MEM_READ_ONLY; // unreachable
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_CL_MEMORY_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_CL_MEMORY_H__
+
+#include <algorithm>
+
+#include "OpenclWrapper.h"
+#include "AccessType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// RAII wrapper for OpenCL memory object.
+//
+// A CLMemory object is movable but not copyable.
+class CLMemory
+{
+public:
+ // Creates invalid object.
+ CLMemory() : CLMemory(nullptr, false) {}
+
+ CLMemory(cl_mem memory, bool has_ownership) : memory_(memory), has_ownership_(has_ownership) {}
+
+ // Move-only
+ CLMemory(const CLMemory &) = delete;
+ CLMemory &operator=(const CLMemory &) = delete;
+ CLMemory(CLMemory &&image) : memory_(image.memory_), has_ownership_(image.has_ownership_)
+ {
+ image.memory_ = nullptr;
+ }
+
+ ~CLMemory() { Invalidate(); }
+
+ CLMemory &operator=(CLMemory &&image)
+ {
+ if (this != &image)
+ {
+ Invalidate();
+ std::swap(memory_, image.memory_);
+ has_ownership_ = image.has_ownership_;
+ }
+ return *this;
+ }
+
+ cl_mem memory() const { return memory_; }
+
+ bool is_valid() const { return memory_ != nullptr; }
+
+ // @return true if this object actually owns the corresponding CL memory
+ // and manages its lifetime.
+ bool has_ownership() const { return has_ownership_; }
+
+ cl_mem Release()
+ {
+ cl_mem to_return = memory_;
+ memory_ = nullptr;
+ return to_return;
+ }
+
+private:
+ void Invalidate()
+ {
+ if (memory_ && has_ownership_)
+ {
+ clReleaseMemObject(memory_);
+ }
+ memory_ = nullptr;
+ }
+
+ cl_mem memory_ = nullptr;
+ bool has_ownership_ = false;
+};
+
+cl_mem_flags ToClMemFlags(AccessType access_type);
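+
+// For example (illustrative), ToClMemFlags(AccessType::READ_WRITE) yields
+// CL_MEM_READ_WRITE, which can be passed directly to clCreateBuffer.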
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_CL_MEMORY_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClProgram.h"
+
+#include <cstdint>
+#include <cstring>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "Util.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+std::string GetProgramBuildInfo(cl_program program, cl_device_id id, cl_program_build_info info)
+{
+ size_t size;
+ cl_int error_code = clGetProgramBuildInfo(program, id, info, 0, nullptr, &size);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::StrCat("Failed to GetProgramBuildInfo - ", CLErrorCodeToString(error_code));
+ }
+
+ std::string result(size - 1, 0);
+ error_code = clGetProgramBuildInfo(program, id, info, size, &result[0], nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::StrCat("Failed to GetProgramBuildInfo - ", CLErrorCodeToString(error_code));
+ }
+ return result;
+}
+
+absl::Status GetBinarySize(cl_program program, size_t *binary_size)
+{
+ cl_int error_code =
+ clGetProgramInfo(program, CL_PROGRAM_BINARY_SIZES, sizeof(size_t), binary_size, nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to get program binary size - ", CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status BuildProgram(cl_program program, const CLDevice &device,
+ const std::string &compiler_options)
+{
+ const int error_code =
+ clBuildProgram(program, 0, nullptr, compiler_options.c_str(), nullptr, nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to build program executable - ", CLErrorCodeToString(error_code),
+ GetProgramBuildInfo(program, device.id(), CL_PROGRAM_BUILD_LOG)));
+ }
+
+ return absl::OkStatus();
+}
+
+std::string CompilerOptionToString(const CLDevice &device, CompilerOptions option)
+{
+ switch (option)
+ {
+ case CompilerOptions::ADRENO_FULL_SIMD_LINE:
+ if (device.info_.adreno_info.gpu_version < 500)
+ {
+ return "-qcom-accelerate-16-bit";
+ }
+ else
+ {
+ return "-qcom-accelerate-16-bit=true";
+ }
+ case CompilerOptions::ADRENO_MORE_WAVES:
+ if (device.info_.adreno_info.gpu_version >= 500)
+ {
+ return "-qcom-accelerate-16-bit=false";
+ }
+ else
+ {
+ return "";
+ }
+ case CompilerOptions::POWERVR_FP16:
+ return "-cl-fast-relaxed-math";
+ case CompilerOptions::CL_OPT_DISABLE:
+ return "-cl-opt-disable";
+ case CompilerOptions::CL_2_0:
+ return "-cl-std=CL2.0";
+ case CompilerOptions::CL_3_0:
+ return "-cl-std=CL3.0";
+ }
+ return "";
+}
+
+} // namespace
+
+std::string CompilerOptionsToString(const CLDevice &device,
+ const std::vector<CompilerOptions> &compiler_options)
+{
+ std::string result;
+ for (auto option : compiler_options)
+ {
+ absl::StrAppend(&result, CompilerOptionToString(device, option), " ");
+ }
+ return result;
+}
+
+CLProgram::CLProgram(cl_program program, cl_device_id device_id)
+ : program_(program), device_id_(device_id)
+{
+}
+
+CLProgram::CLProgram(CLProgram &&program)
+ : program_(program.program_), device_id_(program.device_id_)
+{
+ program.program_ = nullptr;
+}
+
+CLProgram &CLProgram::operator=(CLProgram &&program)
+{
+ if (this != &program)
+ {
+ Release();
+ std::swap(program_, program.program_);
+ std::swap(device_id_, program.device_id_);
+ }
+ return *this;
+}
+
+CLProgram::~CLProgram() { Release(); }
+
+void CLProgram::Release()
+{
+ if (program_)
+ {
+ clReleaseProgram(program_);
+ program_ = nullptr;
+ }
+}
+
+absl::Status CLProgram::GetBinary(std::vector<uint8_t> *result) const
+{
+ size_t binary_size;
+ RETURN_IF_ERROR(GetBinarySize(program_, &binary_size));
+ result->resize(result->size() + binary_size);
+ uint8_t *binary_ptr = result->data() + result->size() - binary_size;
+ cl_int error_code =
+ clGetProgramInfo(program_, CL_PROGRAM_BINARIES, binary_size, &binary_ptr, nullptr);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to get program binary - ", CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status CreateCLProgram(const std::string &code, const std::string &compiler_options,
+ const CLContext &context, const CLDevice &device, CLProgram *result)
+{
+ int error_code;
+ const char *source = code.c_str();
+
+ cl_program program =
+ clCreateProgramWithSource(context.context(), 1, &source, nullptr, &error_code);
+ if (!program || error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to create compute program - ", CLErrorCodeToString(error_code)));
+ }
+
+ *result = CLProgram(program, device.id());
+ RETURN_IF_ERROR(BuildProgram(program, device, compiler_options));
+ return absl::OkStatus();
+}
+
+absl::Status CreateCLProgramFromBinary(const CLContext &context, const CLDevice &device,
+ absl::Span<const uint8_t> binary, CLProgram *result)
+{
+ cl_int binary_status;
+ cl_int error_code;
+ cl_device_id devices_list[] = {device.id()};
+ size_t binary_size = binary.size();
+ const uint8_t *binary_pointer = binary.data();
+ cl_program program = clCreateProgramWithBinary(context.context(), 1, devices_list, &binary_size,
+ &binary_pointer, &binary_status, &error_code);
+ if (binary_status != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat(
+ "Something wrong with binary after clCreateProgramWithBinary - ", binary_status));
+ }
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(
+ absl::StrCat("Failed to create program - ", CLErrorCodeToString(error_code)));
+ }
+ *result = CLProgram(program, device.id());
+ return BuildProgram(program, device, "");
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_CL_PROGRAM_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_CL_PROGRAM_H__
+
+#include <cstdint>
+#include <vector>
+
+#include "ClContext.h"
+#include "ClDevice.h"
+#include "OpenclWrapper.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class CompilerOptions
+{
+ // ADRENO_FULL_SIMD_LINE:
+ // Adreno can have two SIMD sizes:
+ // on Adreno 4xx/5xx it is 32/64, on Adreno 6xx it is 64/128.
+ // Some of our algorithms rely on the exact size, for example on the full
+ // SIMD size, so we need this option.
+ // The option maps to -qcom-accelerate-16-bit, but in practice it controls
+ // the SIMD size.
+ ADRENO_FULL_SIMD_LINE,
+ ADRENO_MORE_WAVES,
+ POWERVR_FP16,
+ CL_OPT_DISABLE,
+ CL_2_0,
+ CL_3_0,
+};
+
+std::string CompilerOptionsToString(const CLDevice &device,
+ const std::vector<CompilerOptions> &compiler_options);
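+
+// Illustrative usage (the exact output depends on the device; the value shown
+// assumes an Adreno 5xx or newer):
+//   std::string opts = CompilerOptionsToString(
+//     device, {CompilerOptions::ADRENO_FULL_SIMD_LINE, CompilerOptions::CL_2_0});
+//   // opts == "-qcom-accelerate-16-bit=true -cl-std=CL2.0 "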
+
+class CLProgram
+{
+public:
+ CLProgram() {}
+ CLProgram(cl_program program, cl_device_id device_id);
+
+ // Move only
+ CLProgram(CLProgram &&program);
+ CLProgram &operator=(CLProgram &&program);
+ CLProgram(const CLProgram &) = delete;
+ CLProgram &operator=(const CLProgram &) = delete;
+
+ ~CLProgram();
+
+ cl_program program() const { return program_; }
+
+ // Returns the cl_device_id associated with the program object.
+ // This can be the device associated with the context on which the program
+ // object was created, or the device that was specified when the program
+ // object was created using clCreateProgramWithBinary.
+ cl_device_id GetDeviceId() const { return device_id_; }
+
+ absl::Status GetBinary(std::vector<uint8_t> *result) const;
+
+private:
+ void Release();
+
+ cl_program program_ = nullptr;
+
+ // reference
+ cl_device_id device_id_ = nullptr;
+};
+
+absl::Status CreateCLProgram(const std::string &code, const std::string &compiler_options,
+ const CLContext &context, const CLDevice &device, CLProgram *result);
+
+absl::Status CreateCLProgramFromBinary(const CLContext &context, const CLDevice &device,
+ absl::Span<const uint8_t> binary, CLProgram *result);
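+
+// Illustrative round trip (sketch): compile once with CreateCLProgram, cache the
+// bytes from CLProgram::GetBinary, then rebuild later without recompiling the
+// kernel source:
+//   std::vector<uint8_t> binary;
+//   RETURN_IF_ERROR(program.GetBinary(&binary));
+//   CLProgram restored;
+//   RETURN_IF_ERROR(CreateCLProgramFromBinary(context, device, absl::MakeSpan(binary), &restored));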
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_CL_PROGRAM_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DataType.h"
+
+#include <stddef.h>
+#include <string>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+size_t SizeOf(DataType data_type)
+{
+ switch (data_type)
+ {
+ case DataType::UINT8:
+ case DataType::INT8:
+ return 1;
+ case DataType::FLOAT16:
+ case DataType::INT16:
+ case DataType::UINT16:
+ return 2;
+ case DataType::FLOAT32:
+ case DataType::INT32:
+ case DataType::UINT32:
+ return 4;
+ case DataType::FLOAT64:
+ case DataType::INT64:
+ case DataType::UINT64:
+ return 8;
+ case DataType::UNKNOWN:
+ return 0;
+ }
+ return 0;
+}
+
+std::string ToString(DataType data_type)
+{
+ switch (data_type)
+ {
+ case DataType::FLOAT16:
+ return "float16";
+ case DataType::FLOAT32:
+ return "float32";
+ case DataType::FLOAT64:
+ return "float64";
+ case DataType::INT16:
+ return "int16";
+ case DataType::INT32:
+ return "int32";
+ case DataType::INT64:
+ return "int64";
+ case DataType::INT8:
+ return "int8";
+ case DataType::UINT16:
+ return "uint16";
+ case DataType::UINT32:
+ return "uint32";
+ case DataType::UINT64:
+ return "uint64";
+ case DataType::UINT8:
+ return "uint8";
+ case DataType::UNKNOWN:
+ return "unknown";
+ }
+ return "undefined";
+}
+
+std::string ToCLDataType(DataType data_type, int vec_size)
+{
+ const std::string postfix = vec_size == 1 ? "" : std::to_string(vec_size);
+ switch (data_type)
+ {
+ case DataType::FLOAT16:
+ return "half" + postfix;
+ case DataType::FLOAT32:
+ return "float" + postfix;
+ case DataType::FLOAT64:
+ return "double" + postfix;
+ case DataType::INT16:
+ return "short" + postfix;
+ case DataType::INT32:
+ return "int" + postfix;
+ case DataType::INT64:
+ return "long" + postfix;
+ case DataType::INT8:
+ return "char" + postfix;
+ case DataType::UINT16:
+ return "ushort" + postfix;
+ case DataType::UINT32:
+ return "uint" + postfix;
+ case DataType::UINT64:
+ return "ulong" + postfix;
+ case DataType::UINT8:
+ return "uchar" + postfix;
+ case DataType::UNKNOWN:
+ return "unknown";
+ }
+ return "undefined";
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_DATA_TYPE_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_DATA_TYPE_H__
+
+#include <stddef.h>
+#include <string>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class DataType
+{
+ UNKNOWN = 0,
+ FLOAT16 = 1,
+ FLOAT32 = 2,
+ FLOAT64 = 3,
+ UINT8 = 4,
+ INT8 = 5,
+ UINT16 = 6,
+ INT16 = 7,
+ UINT32 = 8,
+ INT32 = 9,
+ UINT64 = 10,
+ INT64 = 11,
+};
+
+size_t SizeOf(DataType type);
+
+std::string ToString(DataType t);
+
+std::string ToCLDataType(DataType data_type, int vec_size = 1);
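+
+// Illustrative examples: SizeOf(DataType::FLOAT16) == 2 and
+// ToCLDataType(DataType::FLOAT16, 4) == "half4", the element type emitted into
+// OpenCL C kernel source.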
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_DATA_TYPE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DeviceInfo.h"
+
+#include <algorithm>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+namespace
+{
+// Checks that gpu_version belongs to the range [min_version, max_version):
+// min_version is included and max_version is excluded.
+bool IsGPUVersionInRange(int gpu_version, int min_version, int max_version)
+{
+ return gpu_version >= min_version && gpu_version < max_version;
+}
+
+MaliGPU GetMaliGPUVersion(const std::string &device_name)
+{
+ const std::map<std::string, MaliGPU> kMapping = {
+ {"T604", MaliGPU::T604}, {"T622", MaliGPU::T622}, {"T624", MaliGPU::T624},
+ {"T628", MaliGPU::T628}, {"T658", MaliGPU::T658}, {"T678", MaliGPU::T678},
+ {"T720", MaliGPU::T720}, {"T760", MaliGPU::T760}, {"T820", MaliGPU::T820},
+ {"T830", MaliGPU::T830}, {"T860", MaliGPU::T860}, {"T880", MaliGPU::T880},
+ {"G31", MaliGPU::G31}, {"G51", MaliGPU::G51}, {"G71", MaliGPU::G71},
+ {"G52", MaliGPU::G52}, {"G72", MaliGPU::G72}, {"G76", MaliGPU::G76},
+ {"G57", MaliGPU::G57}, {"G77", MaliGPU::G77}, {"G68", MaliGPU::G68},
+ {"G78", MaliGPU::G78},
+ };
+ for (const auto &v : kMapping)
+ {
+ if (device_name.find(v.first) != std::string::npos)
+ {
+ return v.second;
+ }
+ }
+ return MaliGPU::UNKNOWN;
+}
+
+} // namespace
+
+// There is no rule for GPU version encoding, but we found these samples:
+// Version: OpenCL C 2.0 Adreno(TM) 540 // Pixel 2
+// Version: OpenCL C 2.0 Adreno(TM) 630 // Sony Compact XZ2
+// Version: OpenCL C 2.0 Adreno(TM) 630 // Pixel 3
+// Version: OpenCL C 2.0 Adreno(TM) 540 // Samsung S8
+// Version: OpenCL C 1.2 Adreno(TM) 430 // HTC One M9
+// Version: OpenCL C 2.0 Adreno(TM) 530 // Samsung S7 Edge
+// Version: OpenCL C 1.2 Adreno(TM) 405 // Motorola Moto G(4)
+// The string ends after the version number.
+// It is assumed that the <vendor-specific information> for Adreno GPUs has
+// the following format:
+// <text?><space?>Adreno(TM)<space><text?><version>
+// Returns -1 if the vendor-specific information cannot be parsed.
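+// For example (illustrative), GetAdrenoGPUVersion("OpenCL C 2.0 Adreno(TM) 540")
+// returns 540, while an unparsable string returns -1.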
+int GetAdrenoGPUVersion(const std::string &gpu_version)
+{
+ const std::string gpu = absl::AsciiStrToLower(gpu_version);
+ const std::vector<absl::string_view> words = absl::StrSplit(gpu, ' ');
+ size_t i = 0;
+ for (; i < words.size(); ++i)
+ {
+ if (words[i].find("adreno") != words[i].npos)
+ {
+ break;
+ }
+ }
+ i += 1;
+ for (; i < words.size(); ++i)
+ {
+ int number;
+ bool is_number = absl::SimpleAtoi(words[i], &number);
+ // Adreno GPU numbering starts from 2xx, but OpenCL support begins only with 3xx
+ if (is_number && number >= 300)
+ {
+ return number;
+ }
+ }
+ return -1;
+}
+
+std::string VendorToString(Vendor v)
+{
+ switch (v)
+ {
+ case Vendor::kQualcomm:
+ return "Qualcomm";
+ case Vendor::kMali:
+ return "Mali";
+ case Vendor::kPowerVR:
+ return "PowerVR";
+ case Vendor::kNvidia:
+ return "NVIDIA";
+ case Vendor::kAMD:
+ return "AMD";
+ case Vendor::kIntel:
+ return "Intel";
+ case Vendor::kUnknown:
+ return "unknown vendor";
+ default:
+ return "Error";
+ }
+}
+
+std::string OpenCLVersionToString(OpenCLVersion version)
+{
+ switch (version)
+ {
+ case OpenCLVersion::CL_1_0:
+ return "1.0";
+ case OpenCLVersion::CL_1_1:
+ return "1.1";
+ case OpenCLVersion::CL_1_2:
+ return "1.2";
+ case OpenCLVersion::CL_2_0:
+ return "2.0";
+ case OpenCLVersion::CL_2_1:
+ return "2.1";
+ case OpenCLVersion::CL_2_2:
+ return "2.2";
+ case OpenCLVersion::CL_3_0:
+ return "3.0";
+ default:
+ return "Error";
+ }
+}
+
+AdrenoInfo::AdrenoInfo(const std::string &device_version)
+ : gpu_version(GetAdrenoGPUVersion(device_version))
+{
+}
+
+int AdrenoInfo::GetMaximumWavesCount() const
+{
+ if (gpu_version < 400)
+ {
+ return -1; // Adreno 3xx does not support it currently
+ }
+ else if (gpu_version >= 400 && gpu_version < 500)
+ {
+ return -1; // Adreno 4xx does not support it currently
+ }
+ else if (gpu_version >= 500 && gpu_version < 600)
+ {
+ return -1; // Adreno 5xx does not support it currently
+ }
+ else if (gpu_version >= 600 && gpu_version < 700)
+ {
+ return gpu_version == 640 ? 30 : 16;
+ }
+ else
+ {
+ return -1; // Adreno 7xx and higher does not exist yet
+ }
+}
+
+int AdrenoInfo::GetRegisterMemorySizePerComputeUnit() const
+{
+ if (gpu_version < 400)
+ {
+ return -1; // Adreno 3xx does not support it currently
+ }
+ else if (gpu_version >= 400 && gpu_version < 500)
+ {
+ return -1; // Adreno 4xx does not support it currently
+ }
+ else if (gpu_version >= 500 && gpu_version < 600)
+ {
+ return -1; // Adreno 5xx does not support it currently
+ }
+ else if (gpu_version >= 600 && gpu_version < 700)
+ {
+ return gpu_version == 640 ? 128 * 144 * 16 : 128 * 96 * 16;
+ }
+ else
+ {
+ return -1; // Adreno 7xx and higher does not exist yet
+ }
+}
+
+int AdrenoInfo::GetMaximumWavesCount(int register_footprint_per_tread, bool full_wave) const
+{
+ const int register_usage_per_wave = GetWaveSize(full_wave) * register_footprint_per_tread;
+ const int possible_waves_count = GetRegisterMemorySizePerComputeUnit() / register_usage_per_wave;
+ return std::min(possible_waves_count, GetMaximumWavesCount());
+}
+
+int AdrenoInfo::GetWaveSize(bool full_wave) const
+{
+ if (gpu_version < 400)
+ {
+ return -1; // Adreno 3xx does not support it currently
+ }
+ else if (gpu_version < 600)
+ {
+ return full_wave ? 64 : 32;
+ }
+ else
+ {
+ return full_wave ? 128 : 64;
+ }
+}
+
+MaliInfo::MaliInfo(const std::string &device_name) : gpu_version(GetMaliGPUVersion(device_name)) {}
+
+bool MaliInfo::IsMaliT6xx() const
+{
+ return gpu_version == MaliGPU::T604 || gpu_version == MaliGPU::T622 ||
+ gpu_version == MaliGPU::T624 || gpu_version == MaliGPU::T628 ||
+ gpu_version == MaliGPU::T658 || gpu_version == MaliGPU::T678;
+}
+
+bool MaliInfo::IsMaliT7xx() const
+{
+ return gpu_version == MaliGPU::T720 || gpu_version == MaliGPU::T760;
+}
+
+bool MaliInfo::IsMaliT8xx() const
+{
+ return gpu_version == MaliGPU::T820 || gpu_version == MaliGPU::T830 ||
+ gpu_version == MaliGPU::T860 || gpu_version == MaliGPU::T880;
+}
+
+bool MaliInfo::IsMidgard() const { return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx(); }
+
+bool MaliInfo::IsBifrostGen1() const
+{
+ return gpu_version == MaliGPU::G31 || gpu_version == MaliGPU::G51 || gpu_version == MaliGPU::G71;
+}
+
+bool MaliInfo::IsBifrostGen2() const
+{
+ return gpu_version == MaliGPU::G52 || gpu_version == MaliGPU::G72;
+}
+
+bool MaliInfo::IsBifrostGen3() const { return gpu_version == MaliGPU::G76; }
+
+bool MaliInfo::IsBifrost() const { return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3(); }
+
+bool MaliInfo::IsValhall() const
+{
+ return gpu_version == MaliGPU::G57 || gpu_version == MaliGPU::G77 ||
+ gpu_version == MaliGPU::G68 || gpu_version == MaliGPU::G78;
+}
+
+bool DeviceInfo::SupportsTextureArray() const { return cl_version >= OpenCLVersion::CL_1_2; }
+
+bool DeviceInfo::SupportsImageBuffer() const { return cl_version >= OpenCLVersion::CL_1_2; }
+
+bool DeviceInfo::SupportsImage3D() const
+{
+ if (vendor == Vendor::kMali)
+ {
+ // On Mali T880 read_imageh doesn't compile with image3d_t
+ return false;
+ }
+ return supports_image3d_writes;
+}
+
+bool DeviceInfo::SupportsFloatImage2D(DataType data_type, int channels) const
+{
+ if (channels == 1)
+ {
+ return data_type == DataType::FLOAT32 ? supports_r_f32_tex2d : supports_r_f16_tex2d;
+ }
+ else if (channels == 2)
+ {
+ return data_type == DataType::FLOAT32 ? supports_rg_f32_tex2d : supports_rg_f16_tex2d;
+ }
+ else if (channels == 3)
+ {
+ return data_type == DataType::FLOAT32 ? supports_rgb_f32_tex2d : supports_rgb_f16_tex2d;
+ }
+ else if (channels == 4)
+ {
+ return data_type == DataType::FLOAT32 ? supports_rgba_f32_tex2d : supports_rgba_f16_tex2d;
+ }
+ else
+ {
+ return false;
+ }
+}
+
+bool DeviceInfo::SupportsOneLayerTextureArray() const
+{
+ return !IsAdreno() || adreno_info.support_one_layer_texture_array;
+}
+
+bool DeviceInfo::SupportsExtension(const std::string &extension) const
+{
+ for (const auto &ext : extensions)
+ {
+ if (ext == extension)
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool DeviceInfo::IsCL20OrHigher() const
+{
+ return cl_version != OpenCLVersion::CL_1_0 && cl_version != OpenCLVersion::CL_1_1 &&
+ cl_version != OpenCLVersion::CL_1_2;
+}
+
+bool DeviceInfo::SupportsSubGroupWithSize(int sub_group_size) const
+{
+ for (auto subgroup_size : supported_subgroup_sizes)
+ {
+ if (sub_group_size == subgroup_size)
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool DeviceInfo::IsAdreno() const { return vendor == Vendor::kQualcomm; }
+
+bool DeviceInfo::IsAdreno3xx() const
+{
+ return IsAdreno() && IsGPUVersionInRange(adreno_info.gpu_version, 300, 400);
+}
+
+bool DeviceInfo::IsAdreno4xx() const
+{
+ return IsAdreno() && IsGPUVersionInRange(adreno_info.gpu_version, 400, 500);
+}
+
+bool DeviceInfo::IsAdreno5xx() const
+{
+ return IsAdreno() && IsGPUVersionInRange(adreno_info.gpu_version, 500, 600);
+}
+
+bool DeviceInfo::IsAdreno6xx() const
+{
+ return IsAdreno() && IsGPUVersionInRange(adreno_info.gpu_version, 600, 700);
+}
+
+bool DeviceInfo::IsAdreno6xxOrHigher() const
+{
+ return IsAdreno() && adreno_info.gpu_version >= 600;
+}
+
+bool DeviceInfo::IsPowerVR() const { return vendor == Vendor::kPowerVR; }
+
+bool DeviceInfo::IsNvidia() const { return vendor == Vendor::kNvidia; }
+
+bool DeviceInfo::IsMali() const { return vendor == Vendor::kMali; }
+
+bool DeviceInfo::IsAMD() const { return vendor == Vendor::kAMD; }
+
+bool DeviceInfo::IsIntel() const { return vendor == Vendor::kIntel; }
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_DEVICE_INFO_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_DEVICE_INFO_H__
+
+#include <string>
+#include <vector>
+
+#include "DataType.h"
+
+// For use only in DeviceInfo.cc, but kept in this header so tests can exercise it.
+int GetAdrenoGPUVersion(const std::string &gpu_version);
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class Vendor
+{
+ kQualcomm,
+ kMali,
+ kPowerVR,
+ kNvidia,
+ kAMD,
+ kIntel,
+ kUnknown
+};
+std::string VendorToString(Vendor v);
+
+enum class OpenCLVersion
+{
+ UNKNOWN,
+ CL_1_0,
+ CL_1_1,
+ CL_1_2,
+ CL_2_0,
+ CL_2_1,
+ CL_2_2,
+ CL_3_0
+};
+std::string OpenCLVersionToString(OpenCLVersion version);
+
+struct AdrenoInfo
+{
+ AdrenoInfo() = default;
+ explicit AdrenoInfo(const std::string &device_version);
+ int gpu_version = -1; // can be, for example, 405/430/540/530/630 etc.
+
+ // Returns a poorly documented physical parameter of Adreno 6xx GPUs.
+ // We obtained it using the Snapdragon Profiler.
+ int GetMaximumWavesCount() const;
+
+ // Returns the amount of register memory per CU (Compute Unit) in bytes.
+ int GetRegisterMemorySizePerComputeUnit() const;
+
+ // Returns the maximum possible number of waves based on register usage.
+ int GetMaximumWavesCount(int register_footprint_per_tread, bool full_wave = true) const;
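+ // Illustrative example (assuming Adreno 640 and full waves): with a per-thread
+ // footprint of 96, the result is
+ // min((128 * 144 * 16) / (128 * 96), GetMaximumWavesCount()) = min(24, 30) = 24.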
+
+ int GetWaveSize(bool full_wave) const;
+
+ // Not supported on some Adreno devices with specific driver version.
+ // b/131099086
+ bool support_one_layer_texture_array = true;
+};
+
+enum class MaliGPU
+{
+ T604,
+ T622,
+ T624,
+ T628,
+ T658,
+ T678,
+ T720,
+ T760,
+ T820,
+ T830,
+ T860,
+ T880,
+ G31,
+ G51,
+ G71,
+ G52,
+ G72,
+ G76,
+ G57,
+ G77,
+ G68,
+ G78,
+ UNKNOWN
+};
+
+struct MaliInfo
+{
+ MaliInfo() = default;
+ explicit MaliInfo(const std::string &device_name);
+ MaliGPU gpu_version = MaliGPU::UNKNOWN;
+
+ bool IsMaliT6xx() const;
+ bool IsMaliT7xx() const;
+ bool IsMaliT8xx() const;
+ bool IsMidgard() const;
+ bool IsBifrostGen1() const;
+ bool IsBifrostGen2() const;
+ bool IsBifrostGen3() const;
+ bool IsBifrost() const;
+ bool IsValhall() const;
+};
+
+struct DeviceInfo
+{
+ DeviceInfo() = default;
+
+ bool IsAdreno() const;
+ bool IsAdreno3xx() const;
+ bool IsAdreno4xx() const;
+ bool IsAdreno5xx() const;
+ bool IsAdreno6xx() const;
+ bool IsAdreno6xxOrHigher() const;
+ bool IsPowerVR() const;
+ bool IsNvidia() const;
+ bool IsMali() const;
+ bool IsAMD() const;
+ bool IsIntel() const;
+
+ bool SupportsTextureArray() const;
+ bool SupportsImageBuffer() const;
+ bool SupportsImage3D() const;
+
+ bool SupportsFloatImage2D(DataType data_type, int channels) const;
+
+ // To track bug on some Adreno. b/131099086
+ bool SupportsOneLayerTextureArray() const;
+
+ bool SupportsExtension(const std::string &extension) const;
+ bool IsCL20OrHigher() const;
+ bool SupportsSubGroupWithSize(int sub_group_size) const;
+
+ std::vector<std::string> extensions;
+ bool supports_fp16 = false;
+ bool supports_image3d_writes = false;
+ Vendor vendor = Vendor::kUnknown;
+ OpenCLVersion cl_version = OpenCLVersion::UNKNOWN;
+ int compute_units_count = 0;
+ uint64_t buffer_max_size = 0;
+ uint64_t image2d_max_width = 0;
+ uint64_t image2d_max_height = 0;
+ uint64_t image_buffer_max_size = 0;
+ uint64_t image_array_max_layers = 0;
+ uint64_t image3d_max_width = 0;
+ uint64_t image3d_max_height = 0;
+ uint64_t image3d_max_depth = 0;
+ int max_work_group_size_x = 0;
+ int max_work_group_size_y = 0;
+ int max_work_group_size_z = 0;
+ std::vector<int> supported_subgroup_sizes;
+
+ // rtn is ROUND_TO_NEAREST.
+ // With rtn, precision is much better than with rtz (ROUND_TO_ZERO).
+ // Adreno 3xx supports only rtz; Adreno 4xx and newer support rtn.
+ // Mali from T6xx onward supports rtn.
+ // PowerVR supports only rtz.
+ bool supports_fp32_rtn = false;
+ bool supports_fp16_rtn = false;
+
+ bool supports_r_f16_tex2d = false;
+ bool supports_rg_f16_tex2d = false;
+ bool supports_rgb_f16_tex2d = false;
+ bool supports_rgba_f16_tex2d = false;
+
+ bool supports_r_f32_tex2d = false;
+ bool supports_rg_f32_tex2d = false;
+ bool supports_rgb_f32_tex2d = false;
+ bool supports_rgba_f32_tex2d = false;
+
+ AdrenoInfo adreno_info;
+ MaliInfo mali_info;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_DEVICE_INFO_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Environment.h"
+
+#include <string>
+#include <vector>
+
+#include "Util.h"
+#include "Shape.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+Environment::Environment(CLDevice &&device, CLContext &&context, CLCommandQueue &&queue,
+ ProfilingCommandQueue &&profiling_queue)
+ : device_(std::move(device)), context_(std::move(context)), queue_(std::move(queue)),
+ profiling_queue_(std::move(profiling_queue))
+{
+}
+
+Environment::Environment(Environment &&environment)
+ : device_(std::move(environment.device_)), context_(std::move(environment.context_)),
+ queue_(std::move(environment.queue_)),
+ profiling_queue_(std::move(environment.profiling_queue_)),
+ program_cache_(std::move(environment.program_cache_))
+{
+}
+
+Environment &Environment::operator=(Environment &&environment)
+{
+ if (this != &environment)
+ {
+ device_ = std::move(environment.device_);
+ context_ = std::move(environment.context_);
+ queue_ = std::move(environment.queue_);
+ profiling_queue_ = std::move(environment.profiling_queue_);
+ program_cache_ = std::move(environment.program_cache_);
+ }
+ return *this;
+}
+
+absl::Status Environment::Init()
+{
+ if (device().IsAdreno() && device().SupportsTextureArray())
+ {
+ // Some Adreno < 600 have a bug with one-layer texture arrays. b/131099086
+ // If we have a one-layer texture array and write something to it from a
+ // kernel, we get zeroes instead of the actual values.
+ // The same kernel works if we use a texture array with more than one layer.
+ if (device().info_.adreno_info.gpu_version < 600)
+ {
+ GetDevicePtr()->DisableOneLayerTextureArray();
+ }
+ }
+ return absl::OkStatus();
+}
+
+void Environment::SetHighPerformance() const
+{
+ // TODO(sorokin) use cl_perf_hint if available
+}
+
+void Environment::SetDefaultPerformance() const
+{
+ // TODO(sorokin) use cl_perf_hint if available
+}
+
+void Environment::SetLowPerformance() const
+{
+ // TODO(sorokin) use cl_perf_hint if available
+}
+
+std::vector<CalculationsPrecision> Environment::GetSupportedPrecisions() const
+{
+ std::vector<CalculationsPrecision> precisions;
+ for (CalculationsPrecision precision :
+ {CalculationsPrecision::F32, CalculationsPrecision::F32_F16, CalculationsPrecision::F16})
+ {
+ if (IsSupported(precision))
+ {
+ precisions.push_back(precision);
+ }
+ }
+ return precisions;
+}
+
+bool Environment::IsSupported(CalculationsPrecision precision) const
+{
+ switch (precision)
+ {
+ case CalculationsPrecision::F32_F16:
+ case CalculationsPrecision::F16:
+ return device_.SupportsFP16();
+ case CalculationsPrecision::F32:
+ return true;
+ }
+ return false;
+}
+
+std::vector<TensorStorageType> Environment::GetSupportedStorages() const
+{
+ std::vector<TensorStorageType> storage_types;
+ for (auto storage_type :
+ {TensorStorageType::TEXTURE_2D, TensorStorageType::BUFFER, TensorStorageType::TEXTURE_ARRAY,
+ TensorStorageType::IMAGE_BUFFER, TensorStorageType::TEXTURE_3D})
+ {
+ if (IsSupported(storage_type))
+ {
+ storage_types.push_back(storage_type);
+ }
+ }
+ return storage_types;
+}
+
+std::vector<TensorStorageType> Environment::GetSupportedStoragesWithHWZeroClampSupport() const
+{
+ std::vector<TensorStorageType> storage_types;
+ for (auto storage_type : {TensorStorageType::TEXTURE_2D, TensorStorageType::TEXTURE_ARRAY,
+ TensorStorageType::TEXTURE_3D})
+ {
+ if (IsSupported(storage_type))
+ {
+ storage_types.push_back(storage_type);
+ }
+ }
+ return storage_types;
+}
+
+bool Environment::IsSupported(TensorStorageType storage_type) const
+{
+ switch (storage_type)
+ {
+ case TensorStorageType::TEXTURE_2D:
+ return !device_.IsAMD();
+ case TensorStorageType::BUFFER:
+ return true;
+ case TensorStorageType::TEXTURE_ARRAY:
+ return !device_.IsAMD() && device_.SupportsTextureArray();
+ case TensorStorageType::IMAGE_BUFFER:
+ return (device_.IsAdreno() || device_.IsAMD() || device_.IsNvidia()) &&
+ device_.SupportsImageBuffer();
+ case TensorStorageType::TEXTURE_3D:
+ return !device_.IsAMD() && device_.SupportsImage3D();
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return false;
+ case TensorStorageType::UNKNOWN:
+ return false;
+ }
+ return false;
+}
+
+TensorStorageType GetFastestStorageType(const DeviceInfo &gpu_info)
+{
+ if (gpu_info.IsAdreno())
+ {
+ if (gpu_info.IsAdreno6xxOrHigher())
+ {
+ return TensorStorageType::TEXTURE_ARRAY;
+ }
+ else
+ {
+ return TensorStorageType::TEXTURE_2D;
+ }
+ }
+ else if (gpu_info.IsPowerVR())
+ {
+ return TensorStorageType::TEXTURE_2D;
+ }
+ else if (gpu_info.IsMali())
+ {
+ const MaliInfo mali_info = gpu_info.mali_info;
+ if (mali_info.IsMaliT8xx() || mali_info.IsBifrostGen3() || mali_info.IsValhall())
+ {
+ return TensorStorageType::TEXTURE_2D;
+ }
+ else
+ {
+ return TensorStorageType::BUFFER;
+ }
+ }
+ else if (gpu_info.IsNvidia())
+ {
+ return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER
+ : TensorStorageType::BUFFER;
+ }
+ else if (gpu_info.IsAMD())
+ {
+ return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER
+ : TensorStorageType::BUFFER;
+ }
+ else if (gpu_info.IsIntel())
+ {
+ return TensorStorageType::BUFFER;
+ }
+ return TensorStorageType::BUFFER;
+}
+
+TensorStorageType GetStorageTypeWithMinimalMemoryConsumption(const DeviceInfo &gpu_info)
+{
+ if (gpu_info.IsAdreno())
+ {
+ if (gpu_info.IsAdreno3xx() || gpu_info.IsAdreno4xx())
+ {
+ return TensorStorageType::BUFFER;
+ }
+ else
+ {
+ return TensorStorageType::IMAGE_BUFFER;
+ }
+ }
+ else if (gpu_info.IsPowerVR())
+ {
+ return TensorStorageType::BUFFER;
+ }
+ else if (gpu_info.IsMali())
+ {
+ return TensorStorageType::BUFFER;
+ }
+ else if (gpu_info.IsNvidia())
+ {
+ return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER
+ : TensorStorageType::BUFFER;
+ }
+ else if (gpu_info.IsAMD())
+ {
+ return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER
+ : TensorStorageType::BUFFER;
+ }
+ else if (gpu_info.IsIntel())
+ {
+ return TensorStorageType::BUFFER;
+ }
+ return TensorStorageType::BUFFER;
+}
+
+absl::Status CreateEnvironment(Environment *result)
+{
+ CLDevice gpu;
+ RETURN_IF_ERROR(CreateDefaultGPUDevice(&gpu));
+
+ CLContext context;
+ RETURN_IF_ERROR(CreateCLContext(gpu, &context));
+ CLCommandQueue queue;
+ RETURN_IF_ERROR(CreateCLCommandQueue(gpu, context, &queue));
+ ProfilingCommandQueue profiling_queue;
+ RETURN_IF_ERROR(CreateProfilingCommandQueue(gpu, context, &profiling_queue));
+
+ *result =
+ Environment(std::move(gpu), std::move(context), std::move(queue), std::move(profiling_queue));
+ return result->Init();
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_ENVIRONMENT_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_ENVIRONMENT_H__
+
+#include "ClCommandQueue.h"
+#include "ClContext.h"
+#include "ClDevice.h"
+#include "DeviceInfo.h"
+#include "Precision.h"
+#include "TensorType.h"
+#include "DataType.h"
+#include "ProgramCache.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class Environment
+{
+public:
+ Environment() = default;
+ explicit Environment(CLDevice &&device, CLContext &&context, CLCommandQueue &&queue,
+ ProfilingCommandQueue &&profiling_queue);
+ // Move only
+ Environment(Environment &&environment);
+ Environment &operator=(Environment &&environment);
+ Environment(const Environment &) = delete;
+ Environment &operator=(const Environment &) = delete;
+
+ const CLDevice &device() const { return device_; }
+ CLDevice *GetDevicePtr() { return &device_; }
+ const CLDevice *GetDevicePtr() const { return &device_; }
+ CLContext &context() { return context_; }
+ CLCommandQueue *queue() { return &queue_; }
+ ProfilingCommandQueue *profiling_queue() { return &profiling_queue_; }
+ ProgramCache *program_cache() { return &program_cache_; }
+ const ProgramCache *program_cache() const { return &program_cache_; }
+
+ std::vector<CalculationsPrecision> GetSupportedPrecisions() const;
+ bool IsSupported(CalculationsPrecision precision) const;
+ std::vector<TensorStorageType> GetSupportedStorages() const;
+ // Returns storage types that support zero clamping when reading out-of-bounds
+ // in the HW (Height/Width) dimensions.
+ std::vector<TensorStorageType> GetSupportedStoragesWithHWZeroClampSupport() const;
+ bool IsSupported(TensorStorageType storage_type) const;
+
+ absl::Status Init();
+
+ void SetHighPerformance() const;
+ void SetDefaultPerformance() const;
+ void SetLowPerformance() const; // for energy saving
+
+private:
+ CLDevice device_;
+ CLContext context_;
+ CLCommandQueue queue_;
+ ProfilingCommandQueue profiling_queue_;
+ ProgramCache program_cache_;
+};
+
+TensorStorageType GetFastestStorageType(const DeviceInfo &gpu_info);
+TensorStorageType GetStorageTypeWithMinimalMemoryConsumption(const DeviceInfo &gpu_info);
+
+absl::Status CreateEnvironment(Environment *result);
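+
+// Illustrative usage sketch: create the Environment once, then query its
+// capabilities before building kernels.
+//
+//   Environment env;
+//   if (CreateEnvironment(&env).ok())
+//   {
+//     auto precisions = env.GetSupportedPrecisions();
+//     auto storages = env.GetSupportedStorages();
+//   }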
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_ENVIRONMENT_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GpuObject.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+std::string MemoryTypeToCLType(MemoryType type)
+{
+ switch (type)
+ {
+ case MemoryType::GLOBAL:
+ return "__global";
+ case MemoryType::CONSTANT:
+ return "__constant";
+ case MemoryType::LOCAL:
+ return "__local";
+ }
+ return "";
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_GPU_OBJECT_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_GPU_OBJECT_H__
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ClContext.h"
+#include "OpenclWrapper.h"
+#include "AccessType.h"
+#include "DataType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+struct GPUImage2DDescriptor
+{
+ DataType data_type = DataType::UNKNOWN;
+ AccessType access_type = AccessType::UNKNOWN;
+ cl_mem memory = nullptr;
+};
+
+struct GPUImage3DDescriptor
+{
+ DataType data_type = DataType::UNKNOWN;
+ AccessType access_type = AccessType::UNKNOWN;
+ cl_mem memory = nullptr;
+};
+
+struct GPUImage2DArrayDescriptor
+{
+ DataType data_type = DataType::UNKNOWN;
+ AccessType access_type = AccessType::UNKNOWN;
+ cl_mem memory = nullptr;
+};
+
+struct GPUImageBufferDescriptor
+{
+ DataType data_type = DataType::UNKNOWN;
+ AccessType access_type = AccessType::UNKNOWN;
+ cl_mem memory = nullptr;
+};
+
+struct GPUCustomMemoryDescriptor
+{
+ std::string type_name = "";
+ cl_mem memory = nullptr;
+};
+
+enum class MemoryType
+{
+ GLOBAL,
+ CONSTANT,
+ LOCAL
+};
+
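+// Maps a MemoryType to its OpenCL address-space qualifier,
+// e.g. MemoryTypeToCLType(MemoryType::GLOBAL) returns "__global".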
+std::string MemoryTypeToCLType(MemoryType type);
+
+struct GPUBufferDescriptor
+{
+ DataType data_type = DataType::UNKNOWN;
+ AccessType access_type = AccessType::UNKNOWN;
+ int element_size = 0;
+ MemoryType memory_type = MemoryType::GLOBAL;
+ std::vector<std::string> attributes;
+ cl_mem memory = nullptr;
+};
+
+struct GPUResources
+{
+ std::vector<std::string> ints;
+ std::vector<std::string> floats;
+ std::vector<std::pair<std::string, GPUBufferDescriptor>> buffers;
+ std::vector<std::pair<std::string, GPUImage2DDescriptor>> images2d;
+ std::vector<std::pair<std::string, GPUImage2DArrayDescriptor>> image2d_arrays;
+ std::vector<std::pair<std::string, GPUImage3DDescriptor>> images3d;
+ std::vector<std::pair<std::string, GPUImageBufferDescriptor>> image_buffers;
+ std::vector<std::pair<std::string, GPUCustomMemoryDescriptor>> custom_memories;
+
+ std::vector<std::string> GetNames() const
+ {
+ std::vector<std::string> names = ints;
+ names.insert(names.end(), floats.begin(), floats.end());
+ for (const auto &obj : buffers)
+ {
+ names.push_back(obj.first);
+ }
+ for (const auto &obj : images2d)
+ {
+ names.push_back(obj.first);
+ }
+ for (const auto &obj : image2d_arrays)
+ {
+ names.push_back(obj.first);
+ }
+ for (const auto &obj : images3d)
+ {
+ names.push_back(obj.first);
+ }
+ for (const auto &obj : image_buffers)
+ {
+ names.push_back(obj.first);
+ }
+ for (const auto &obj : custom_memories)
+ {
+ names.push_back(obj.first);
+ }
+ return names;
+ }
+};
+
+struct GPUResourcesWithValue
+{
+ std::vector<std::pair<std::string, int>> ints;
+ std::vector<std::pair<std::string, float>> floats;
+ std::vector<std::pair<std::string, cl_mem>> buffers;
+ std::vector<std::pair<std::string, cl_mem>> images2d;
+ std::vector<std::pair<std::string, cl_mem>> image2d_arrays;
+ std::vector<std::pair<std::string, cl_mem>> images3d;
+ std::vector<std::pair<std::string, cl_mem>> image_buffers;
+ std::vector<std::pair<std::string, cl_mem>> custom_memories;
+};
+
+class GPUObject;
+
+class GPUObjectDescriptor
+{
+public:
+ GPUObjectDescriptor() = default;
+ GPUObjectDescriptor(const GPUObjectDescriptor &) = default;
+ GPUObjectDescriptor &operator=(const GPUObjectDescriptor &) = default;
+ GPUObjectDescriptor(GPUObjectDescriptor &&obj_desc) : state_vars_(std::move(obj_desc.state_vars_))
+ {
+ }
+ GPUObjectDescriptor &operator=(GPUObjectDescriptor &&obj_desc)
+ {
+ if (this != &obj_desc)
+ {
+ state_vars_ = std::move(obj_desc.state_vars_);
+ }
+ return *this;
+ }
+ virtual ~GPUObjectDescriptor() = default;
+
+ void SetStateVar(const std::string &key, const std::string &value) const
+ {
+ state_vars_[key] = value;
+ }
+
+ virtual std::string PerformConstExpr(const std::string &) const { return ""; }
+
+ virtual absl::Status PerformSelector(const std::string &, const std::vector<std::string> &,
+ const std::vector<std::string> &, std::string *result) const
+ {
+ *result = "";
+ return absl::OkStatus();
+ }
+ virtual GPUResources GetGPUResources() const { return GPUResources(); }
+
+ virtual absl::Status CreateGPUObject(CLContext *, std::unique_ptr<GPUObject> *) const
+ {
+ return absl::OkStatus();
+ }
+ virtual void Release() {}
+
+ void SetAccess(AccessType access_type) { access_type_ = access_type; }
+ AccessType GetAccess() const { return access_type_; }
+
+protected:
+ // friend flatbuffers::Offset<data::GPUObjectDescriptor> Encode(
+ // const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder);
+ // friend void Decode(const data::GPUObjectDescriptor* fb_obj,
+ // GPUObjectDescriptor* obj);
+ mutable std::map<std::string, std::string> state_vars_;
+ AccessType access_type_ = AccessType::UNKNOWN;
+};
+
+using GPUObjectDescriptorPtr = std::unique_ptr<GPUObjectDescriptor>;
+
+class GPUObject
+{
+public:
+ GPUObject() = default;
+ // Move only
+ GPUObject(GPUObject &&obj_desc) = default;
+ GPUObject &operator=(GPUObject &&obj_desc) = default;
+ GPUObject(const GPUObject &) = delete;
+ GPUObject &operator=(const GPUObject &) = delete;
+ virtual ~GPUObject() = default;
+ virtual absl::Status GetGPUResources(const GPUObjectDescriptor *obj_ptr,
+ GPUResourcesWithValue *resources) const = 0;
+};
+
+using GPUObjectPtr = std::unique_ptr<GPUObject>;
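+
+// Note: a GPUObjectDescriptor describes a GPU resource and can emit OpenCL source
+// snippets via PerformSelector(); CreateGPUObject() materializes it into a concrete
+// GPUObject, whose GetGPUResources() supplies the cl_mem handles bound at dispatch
+// time. TensorLinearDescriptor / LinearStorage (LinearStorage.h) form one such pair.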
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_GPU_OBJECT_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "InferenceContext.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include <unordered_map>
+
+#include "Buffer.h"
+#include "ClDevice.h"
+
+#include "kernels/GpuOperation.h"
+#include "ModelHints.h"
+#include "Precision.h"
+#include "StorageTypeUtil.h"
+#include "TensorType.h"
+#include "DataType.h"
+#include "Model.h"
+#include "Operations.h"
+#include "Shape.h"
+#include "Types.h"
+#include "Util.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+CLNode::CLNode(CLNode &&node)
+ : operation(std::move(node.operation)), inputs(std::move(node.inputs)),
+ outputs(std::move(node.outputs)), name(std::move(node.name))
+{
+}
+
+CLNode &CLNode::operator=(CLNode &&node)
+{
+ if (this != &node)
+ {
+ operation = std::move(node.operation);
+ inputs = std::move(node.inputs);
+ outputs = std::move(node.outputs);
+ name = std::move(node.name);
+ }
+ return *this;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_INFERENCE_CONTEXT_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_INFERENCE_CONTEXT_H__
+
+#include <cstdint>
+#include <functional>
+#include <map>
+#include <memory>
+#include <vector>
+#include <unordered_map>
+
+#include "Buffer.h"
+#include "ClCommandQueue.h"
+#include "Environment.h"
+#include "GpuObject.h"
+#include "kernels/GpuOperation.h"
+#include "ModelHints.h"
+#include "OpenclWrapper.h"
+#include "Precision.h"
+#include "TensorType.h"
+#include "Model.h"
+#include "InternalTensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+struct CLNode
+{
+ std::unique_ptr<GPUOperation> operation;
+ std::vector<ValueId> inputs;
+ std::vector<ValueId> outputs;
+
+ // Mostly for debug purposes.
+ std::string name;
+
+ CLNode() = default;
+
+ CLNode(CLNode &&node);
+ CLNode &operator=(CLNode &&node);
+ CLNode(const CLNode &) = delete;
+ CLNode &operator=(const CLNode &) = delete;
+};
+
+class InferenceContext
+{
+public:
+ struct CreateInferenceInfo
+ {
+ CalculationsPrecision precision;
+ TensorStorageType storage_type;
+ ModelHints hints;
+ };
+
+ struct DummyTensor
+ {
+ BHWC shape;
+ TensorDescriptor descriptor;
+
+ bool operator==(const DummyTensor &b) const
+ {
+ return shape == b.shape && descriptor == b.descriptor;
+ }
+ };
+
+ class TensorReserver
+ {
+ public:
+ ValueId Add(std::shared_ptr<DummyTensor> dummy)
+ {
+ reservations_[next_] = std::move(dummy);
+ return next_++;
+ }
+ void Add(ValueId id, std::shared_ptr<DummyTensor> dummy)
+ {
+ reservations_[id] = std::move(dummy);
+ }
+ void SetNext(ValueId id) { next_ = id; }
+ bool HaveTensor(ValueId id) { return reservations_.find(id) != reservations_.end(); }
+ std::shared_ptr<DummyTensor> Get(ValueId id) { return reservations_[id]; }
+
+ std::vector<std::pair<ValueId, TensorDescriptor>> GetTensorDescs() const
+ {
+ std::vector<std::pair<ValueId, TensorDescriptor>> result;
+ for (auto &v : reservations_)
+ {
+ TensorDescriptor desc = v.second->descriptor;
+ desc.shape.b = v.second->shape.b;
+ desc.shape.h = v.second->shape.h;
+ desc.shape.w = v.second->shape.w;
+ desc.shape.d = 1;
+ desc.shape.c = v.second->shape.c;
+ result.push_back({v.first, desc});
+ }
+ return result;
+ }
+
+ void Add(const std::vector<std::pair<ValueId, TensorDescriptor>> &tensors)
+ {
+ for (auto &v : tensors)
+ {
+ auto dummy = std::make_shared<DummyTensor>();
+ dummy->descriptor = v.second;
+ dummy->shape.b = v.second.shape.b;
+ dummy->shape.h = v.second.shape.h;
+ dummy->shape.w = v.second.shape.w;
+ dummy->shape.c = v.second.shape.c;
+ Add(v.first, dummy);
+ }
+ }
+
+ private:
+ std::unordered_map<ValueId, std::shared_ptr<DummyTensor>> reservations_;
+ ValueId next_ = 0;
+ };
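+
+ // Illustrative TensorReserver usage (a sketch; assumes BHWC has a (b, h, w, c)
+ // constructor as declared in Shape.h):
+ //
+ //   TensorReserver reserver;
+ //   auto dummy = std::make_shared<DummyTensor>();
+ //   dummy->shape = BHWC(1, 32, 32, 16);
+ //   ValueId id = reserver.Add(dummy);
+ //   TensorDescriptor desc = reserver.Get(id)->descriptor;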
+
+private:
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_INFERENCE_CONTEXT_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_INTERNAL_TENSOR_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_INTERNAL_TENSOR_H__
+
+#include <stdint.h>
+
+#include <vector>
+
+#include "DataType.h"
+#include "Shape.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace internal_tensor
+{
+
+// Meta-function that maps an element type to the container type used for Tensor data.
+template <DataType Type> struct StorageType;
+
+template <> struct StorageType<DataType::FLOAT32>
+{
+ using value = std::vector<float>;
+};
+
+template <> struct StorageType<DataType::INT32>
+{
+ using value = std::vector<int32_t>;
+};
+
+} // namespace internal_tensor
+
+template <typename ShapeT, DataType Type> struct InternalTensor
+{
+ using ShapeType = ShapeT;
+
+ constexpr static DataType kType = Type;
+
+ using TensorStorageType = typename internal_tensor::StorageType<Type>::value;
+
+ // Opaque id of a tensor.
+ int64_t id = -1;
+
+ ShapeType shape;
+
+ TensorStorageType data;
+};
+
+// TensorRef is a reference to another tensor. If an object should never hold
+// tensor data, then TensorRef should be used instead.
+template <typename ShapeT> struct TensorRef
+{
+ using ShapeType = ShapeT;
+
+ DataType type = DataType::UNKNOWN;
+
+ ShapeT shape;
+
+ // Opaque reference to a tensor. Upstream component is responsible for
+ // resolving this reference into an actual tensor.
+ int64_t ref = -1;
+
+ // Specifies if the tensor should be a variable input tensor that must be an
+ // output as well as an input to the graph.
+ bool is_variable_input = false;
+};
+
+template <typename ShapeT, DataType Type> constexpr DataType InternalTensor<ShapeT, Type>::kType;
+
+template <typename ShapeT, DataType Type>
+InternalTensor<ShapeT, Type> MakeZeroTensor(const ShapeT &shape)
+{
+ InternalTensor<ShapeT, Type> tensor;
+ tensor.shape = shape;
+ tensor.data =
+ typename InternalTensor<ShapeT, Type>::TensorStorageType(shape.DimensionsProduct(), 0);
+ return tensor;
+}
+
+using TensorFloat32 = InternalTensor<BHWC, DataType::FLOAT32>;
+using Tensor5DFloat32 = InternalTensor<BHWDC, DataType::FLOAT32>;
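+
+// Example (illustrative; assumes BHWC provides a (b, h, w, c) constructor as
+// declared in Shape.h):
+//
+//   TensorFloat32 bias = MakeZeroTensor<BHWC, DataType::FLOAT32>(BHWC(1, 1, 1, 64));
+//   // bias.data now holds 1 * 1 * 1 * 64 zero-initialized floats.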
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_INTERNAL_TENSOR_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LinearStorage.h"
+
+#include "absl/strings/str_cat.h"
+#include "DataType.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+TensorLinearDescriptor::TensorLinearDescriptor(TensorLinearDescriptor &&desc)
+ : GPUObjectDescriptor(std::move(desc)), storage_type(desc.storage_type),
+ element_type(desc.element_type), memory_type(desc.memory_type), size(desc.size),
+ data(std::move(desc.data))
+{
+}
+
+TensorLinearDescriptor &TensorLinearDescriptor::operator=(TensorLinearDescriptor &&desc)
+{
+ if (this != &desc)
+ {
+ std::swap(storage_type, desc.storage_type);
+ std::swap(element_type, desc.element_type);
+ std::swap(memory_type, desc.memory_type);
+ std::swap(size, desc.size);
+ data = std::move(desc.data);
+ GPUObjectDescriptor::operator=(std::move(desc));
+ }
+ return *this;
+}
+
+void TensorLinearDescriptor::Release() { data.clear(); }
+
+GPUResources TensorLinearDescriptor::GetGPUResources() const
+{
+ GPUResources resources;
+ resources.ints.push_back("length");
+ if (storage_type == LinearStorageType::BUFFER)
+ {
+ GPUBufferDescriptor desc;
+ desc.data_type = element_type;
+ desc.access_type = access_type_;
+ desc.element_size = 4;
+ desc.memory_type = memory_type;
+ resources.buffers.push_back({"buffer", desc});
+ }
+ else
+ {
+ GPUImage2DDescriptor desc;
+ desc.data_type = element_type;
+ desc.access_type = access_type_;
+ resources.images2d.push_back({"tex2d", desc});
+ }
+ return resources;
+}
+
+absl::Status TensorLinearDescriptor::PerformSelector(const std::string &selector,
+ const std::vector<std::string> &args,
+ const std::vector<std::string> &,
+ std::string *result) const
+{
+ if (selector == "Length")
+ {
+ *result = "length";
+ return absl::OkStatus();
+ }
+ else if (selector == "Read")
+ {
+ return PerformReadSelector(args, result);
+ }
+ else if (selector == "GetPtr")
+ {
+ if (storage_type != LinearStorageType::BUFFER)
+ {
+ return absl::InvalidArgumentError(
+ "GetPtr selector supported for LinearStorageType::BUFFER only.");
+ }
+ *result = "buffer";
+ return absl::OkStatus();
+ }
+ else
+ {
+ return absl::NotFoundError(
+ absl::StrCat("TensorLinearDescriptor don't have selector with name - ", selector));
+ }
+}
+
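+// Emits the OpenCL expression that reads element args[0]: "buffer[i]" for
+// LinearStorageType::BUFFER, otherwise a read_imagef/read_imageh call on "tex2d".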
+absl::Status TensorLinearDescriptor::PerformReadSelector(const std::vector<std::string> &args,
+ std::string *result) const
+{
+ if (args.size() != 1)
+ {
+ return absl::NotFoundError(absl::StrCat(
+ "TensorLinearDescriptor Read require one argument, but ", args.size(), " was passed"));
+ }
+ if (storage_type == LinearStorageType::BUFFER)
+ {
+ *result = absl::StrCat("buffer[", args[0], "]");
+ return absl::OkStatus();
+ }
+ else
+ {
+ const std::string read = element_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
+ *result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0], ", 0))");
+ return absl::OkStatus();
+ }
+}
+
+absl::Status TensorLinearDescriptor::CreateGPUObject(CLContext *context, GPUObjectPtr *result) const
+{
+ LinearStorage gpu_storage;
+ RETURN_IF_ERROR(gpu_storage.CreateFromTensorLinearDescriptor(*this, context));
+ *result = absl::make_unique<LinearStorage>(std::move(gpu_storage));
+ return absl::OkStatus();
+}
+
+void TensorLinearDescriptor::UploadLinearData(const InternalTensor<Linear, DataType::FLOAT32> &src,
+ int aligned_size)
+{
+ size = aligned_size == 0 ? DivideRoundUp(src.shape.v, 4) : aligned_size;
+ if (element_type == DataType::FLOAT32)
+ {
+ data.resize(size * sizeof(float) * 4);
+ float *gpu_data = reinterpret_cast<float *>(data.data());
+ for (int i = 0; i < size * 4; ++i)
+ {
+ if (i < src.shape.v)
+ {
+ gpu_data[i] = src.data[i];
+ }
+ else
+ {
+ gpu_data[i] = 0.0f;
+ }
+ }
+ }
+ // TODO: FLOAT16 upload is not supported yet. The commented-out path below
+ // sketches the intended F16 handling.
+ //
+ // else {
+ // data.resize(size * sizeof(half) * 4);
+ // half* gpu_data = reinterpret_cast<half*>(data.data());
+ // for (int i = 0; i < size * 4; ++i) {
+ // if (i < src.shape.v) {
+ // gpu_data[i] = src.data[i];
+ // } else {
+ // gpu_data[i] = 0.0f;
+ // }
+ // }
+ // }
+}
+
+void LinearStorage::Release()
+{
+ if (memory_)
+ {
+ clReleaseMemObject(memory_);
+ memory_ = nullptr;
+ }
+}
+
+LinearStorage::LinearStorage(LinearStorage &&storage)
+ : GPUObject(std::move(storage)), memory_(storage.memory_), depth_(storage.depth_),
+ storage_type_(storage.storage_type_)
+{
+ storage.memory_ = nullptr;
+}
+
+LinearStorage &LinearStorage::operator=(LinearStorage &&storage)
+{
+ if (this != &storage)
+ {
+ Release();
+ std::swap(memory_, storage.memory_);
+ std::swap(depth_, storage.depth_);
+ std::swap(storage_type_, storage.storage_type_);
+ GPUObject::operator=(std::move(storage));
+ }
+ return *this;
+}
+
+absl::Status LinearStorage::GetGPUResources(const GPUObjectDescriptor *obj_ptr,
+ GPUResourcesWithValue *resources) const
+{
+ const auto *linear_desc = dynamic_cast<const TensorLinearDescriptor *>(obj_ptr);
+ if (!linear_desc)
+ {
+ return absl::InvalidArgumentError("Expected TensorLinearDescriptor on input.");
+ }
+
+ resources->ints.push_back({"length", depth_});
+
+ if (storage_type_ == LinearStorageType::BUFFER)
+ {
+ resources->buffers.push_back({"buffer", memory_});
+ }
+ else
+ {
+ resources->images2d.push_back({"tex2d", memory_});
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status LinearStorage::CreateFromTensorLinearDescriptor(const TensorLinearDescriptor &desc,
+ CLContext *context)
+{
+ storage_type_ = desc.storage_type;
+ depth_ = desc.size;
+ uint8_t *data_ptr = desc.data.empty() ? nullptr : const_cast<unsigned char *>(desc.data.data());
+ if (storage_type_ == LinearStorageType::BUFFER)
+ {
+ bool read_only = desc.memory_type == MemoryType::CONSTANT;
+ // TODO: FLOAT16 is not supported yet; the FLOAT32 element size is used below.
+ //
+ // const int float4_size = desc.element_type == DataType::FLOAT32
+ // ? sizeof(float) * 4
+ // : sizeof(half) * 4;
+ const int float4_size = sizeof(float) * 4;
+ return CreateCLBuffer(context->context(), depth_ * float4_size, read_only, data_ptr, &memory_);
+ }
+ else
+ {
+ return CreateRGBAImage2D(context->context(), depth_, 1,
+ DataTypeToChannelType(desc.element_type), data_ptr, &memory_);
+ }
+}
+
+LinearStorageType DeduceLinearStorageType(TensorStorageType tensor_storage_type)
+{
+ if (tensor_storage_type == TensorStorageType::BUFFER)
+ {
+ return LinearStorageType::BUFFER;
+ }
+ else
+ {
+ return LinearStorageType::TEXTURE_2D;
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_LINEAR_STORAGE_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_LINEAR_STORAGE_H__
+
+#include <string>
+#include <utility>
+
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "GpuObject.h"
+#include "OpenclWrapper.h"
+#include "TensorType.h"
+#include "Util.h"
+#include "DataType.h"
+#include "Status.h"
+#include "Types.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class LinearStorageType
+{
+ BUFFER,
+ TEXTURE_2D
+};
+
+struct TensorLinearDescriptor : public GPUObjectDescriptor
+{
+ LinearStorageType storage_type;
+ DataType element_type; // FLOAT32 or FLOAT16
+ MemoryType memory_type = MemoryType::GLOBAL; // applicable for BUFFER
+
+ // optional
+ int size = 0;
+ std::vector<uint8_t> data;
+
+ TensorLinearDescriptor() = default;
+ TensorLinearDescriptor(const TensorLinearDescriptor &) = default;
+ TensorLinearDescriptor &operator=(const TensorLinearDescriptor &) = default;
+ TensorLinearDescriptor(TensorLinearDescriptor &&desc);
+ TensorLinearDescriptor &operator=(TensorLinearDescriptor &&desc);
+
+ void UploadLinearData(const InternalTensor<Linear, DataType::FLOAT32> &src, int aligned_size = 0);
+
+ absl::Status PerformSelector(const std::string &selector, const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const override;
+
+ GPUResources GetGPUResources() const override;
+ absl::Status PerformReadSelector(const std::vector<std::string> &args, std::string *result) const;
+
+ absl::Status CreateGPUObject(CLContext *context, GPUObjectPtr *result) const override;
+ void Release() override;
+};
+
+LinearStorageType DeduceLinearStorageType(TensorStorageType tensor_storage_type);
+
+// Represents a GPU 1D array of FLT4 (float4/half4) values.
+// Can be backed by either a texture2d or a buffer.
+class LinearStorage : public GPUObject
+{
+public:
+ LinearStorage() {}
+ ~LinearStorage() override { Release(); }
+
+ // Move only
+ LinearStorage(LinearStorage &&storage);
+ LinearStorage &operator=(LinearStorage &&storage);
+ LinearStorage(const LinearStorage &) = delete;
+ LinearStorage &operator=(const LinearStorage &) = delete;
+
+ absl::Status GetGPUResources(const GPUObjectDescriptor *obj_ptr,
+ GPUResourcesWithValue *resources) const override;
+
+ absl::Status CreateFromTensorLinearDescriptor(const TensorLinearDescriptor &desc,
+ CLContext *context);
+
+private:
+ void Release();
+
+ cl_mem memory_ = nullptr;
+ int depth_;
+ LinearStorageType storage_type_;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_LINEAR_STORAGE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_MODEL_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_MODEL_H__
+
+#include <string>
+
+#include "absl/types/any.h"
+#include "InternalTensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// Yet another representation of a CNN graph. Its primary purpose is to simplify
+// graph manipulation.
+
+using ValueId = uint32_t;
+
+// Used to emulate quantized behavior.
+struct QuantizationParams
+{
+ float min = 0;
+ float max = 0;
+ float scale = 0;
+};
+
+struct Operation
+{
+ std::string type;
+ absl::any attributes;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_MODEL_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_MODEL_HINTS_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_MODEL_HINTS_H__
+
+#include <cstdint>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+struct ModelHints
+{
+ using ModelHint = uint64_t;
+
+ // By default we want the fastest inference.
+ static constexpr ModelHint kFastestInference = 0x00000000;
+ // Can improve compilation time, but inference can be slower.
+ static constexpr ModelHint kReduceKernelsCount = 0x00000001;
+ // Can improve tuning time, but inference can be slower.
+ static constexpr ModelHint kFastTuning = 0x00000002;
+
+ // Experimental.
+ // Can improve performance and memory consumption, but slow down
+ // initialization a lot and create more kernels.
+ static constexpr ModelHint kAllowSpecialKernels = 0x00000004;
+
+ void Add(ModelHint hint)
+ {
+ if (hint == kFastestInference)
+ {
+ hints = kFastestInference;
+ }
+ else
+ {
+ hints |= hint;
+ }
+ }
+
+ bool Check(ModelHint hint) const { return hints & hint; }
+
+ uint64_t hints = kFastestInference;
+};
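+
+// Illustrative usage sketch:
+//
+//   ModelHints hints;
+//   hints.Add(ModelHints::kReduceKernelsCount);
+//   hints.Add(ModelHints::kFastTuning);
+//   bool fast_tuning = hints.Check(ModelHints::kFastTuning); // true
+//   hints.Add(ModelHints::kFastestInference);                // resets all hints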
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_MODEL_HINTS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#if defined(_WIN32)
+#define __WINDOWS__
+#endif
+
+#include "OpenclWrapper.h"
+
+#ifdef __WINDOWS__
+#include <windows.h>
+#else
+#include <dlfcn.h>
+#endif
+
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+#ifdef __ANDROID__
+#define LoadFunction(function) \
+ if (use_wrapper) \
+ { \
+ function = reinterpret_cast<PFN_##function>(loadOpenCLPointer(#function)); \
+ } \
+ else \
+ { \
+ function = reinterpret_cast<PFN_##function>(dlsym(*libopencl, #function)); \
+ }
+#elif defined(__WINDOWS__)
+#define LoadFunction(function) \
+ function = reinterpret_cast<PFN_##function>(GetProcAddress(libopencl, #function));
+#else
+#define LoadFunction(function) \
+ function = reinterpret_cast<PFN_##function>(dlsym(*libopencl, #function));
+#endif
+
+#ifdef __WINDOWS__
+void LoadOpenCLFunctions(HMODULE libopencl);
+#else
+void LoadOpenCLFunctions(void **libopencl, bool use_wrapper);
+#endif
+
+absl::Status LoadOpenCL(void **libopencl)
+{
+#ifdef __WINDOWS__
+ HMODULE libopencl = LoadLibraryA("OpenCL.dll");
+ if (libopencl)
+ {
+ LoadOpenCLFunctions(libopencl);
+ return absl::OkStatus();
+ }
+ else
+ {
+ DWORD error_code = GetLastError();
+ return absl::UnknownError(
+ absl::StrCat("Can not open OpenCL library on this device, error code - ", error_code));
+ }
+#else
+ *libopencl = dlopen("libOpenCL.so", RTLD_NOW | RTLD_LOCAL);
+ if (*libopencl)
+ {
+ LoadOpenCLFunctions(libopencl, false);
+ return absl::OkStatus();
+ }
+ // record error
+ std::string error(dlerror());
+#ifdef __ANDROID__
+ // Pixel and automotive ("car") builds expose OpenCL under different library names.
+ *libopencl = dlopen("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL);
+ if (!*libopencl)
+ {
+ *libopencl = dlopen("libOpenCL-car.so", RTLD_NOW | RTLD_LOCAL);
+ }
+ if (*libopencl)
+ {
+ typedef void (*enableOpenCL_t)();
+ enableOpenCL_t enableOpenCL =
+ reinterpret_cast<enableOpenCL_t>(dlsym(*libopencl, "enableOpenCL"));
+ enableOpenCL();
+ LoadOpenCLFunctions(libopencl, true);
+ return absl::OkStatus();
+ }
+#endif
+ return absl::UnknownError(absl::StrCat("Can not open OpenCL library on this device - ", error));
+#endif
+}
+
+void UnloadOpenCL(void *libopencl)
+{
+ if (libopencl)
+ {
+ dlclose(libopencl);
+ }
+}
+
+#ifdef __WINDOWS__
+void LoadOpenCLFunctions(HMODULE libopencl)
+{
+#else
+#ifdef __ANDROID__
+void LoadOpenCLFunctions(void **libopencl, bool use_wrapper)
+{
+ typedef void *(*loadOpenCLPointer_t)(const char *name);
+ loadOpenCLPointer_t loadOpenCLPointer;
+ if (use_wrapper)
+ {
+ loadOpenCLPointer =
+ reinterpret_cast<loadOpenCLPointer_t>(dlsym(*libopencl, "loadOpenCLPointer"));
+ }
+#else
+void LoadOpenCLFunctions(void **libopencl, bool)
+{
+#endif // __ANDROID__
+#endif // __WINDOWS__
+
+ LoadFunction(clGetPlatformIDs);
+ LoadFunction(clGetPlatformInfo);
+ LoadFunction(clGetDeviceIDs);
+ LoadFunction(clGetDeviceInfo);
+ LoadFunction(clCreateSubDevices);
+ LoadFunction(clRetainDevice);
+ LoadFunction(clReleaseDevice);
+ LoadFunction(clCreateContext);
+ LoadFunction(clCreateContextFromType);
+ LoadFunction(clRetainContext);
+ LoadFunction(clReleaseContext);
+ LoadFunction(clGetContextInfo);
+ LoadFunction(clCreateCommandQueueWithProperties);
+ LoadFunction(clRetainCommandQueue);
+ LoadFunction(clReleaseCommandQueue);
+ LoadFunction(clGetCommandQueueInfo);
+ LoadFunction(clCreateBuffer);
+ LoadFunction(clCreateSubBuffer);
+ LoadFunction(clCreateImage);
+ LoadFunction(clCreatePipe);
+ LoadFunction(clRetainMemObject);
+ LoadFunction(clReleaseMemObject);
+ LoadFunction(clGetSupportedImageFormats);
+ LoadFunction(clGetMemObjectInfo);
+ LoadFunction(clGetImageInfo);
+ LoadFunction(clGetPipeInfo);
+ LoadFunction(clSetMemObjectDestructorCallback);
+ LoadFunction(clSVMAlloc);
+ LoadFunction(clSVMFree);
+ LoadFunction(clCreateSamplerWithProperties);
+ LoadFunction(clRetainSampler);
+ LoadFunction(clReleaseSampler);
+ LoadFunction(clGetSamplerInfo);
+ LoadFunction(clCreateProgramWithSource);
+ LoadFunction(clCreateProgramWithBinary);
+ LoadFunction(clCreateProgramWithBuiltInKernels);
+ LoadFunction(clRetainProgram);
+ LoadFunction(clReleaseProgram);
+ LoadFunction(clBuildProgram);
+ LoadFunction(clCompileProgram);
+ LoadFunction(clLinkProgram);
+ LoadFunction(clUnloadPlatformCompiler);
+ LoadFunction(clGetProgramInfo);
+ LoadFunction(clGetProgramBuildInfo);
+ LoadFunction(clCreateKernel);
+ LoadFunction(clCreateKernelsInProgram);
+ LoadFunction(clRetainKernel);
+ LoadFunction(clReleaseKernel);
+ LoadFunction(clSetKernelArg);
+ LoadFunction(clSetKernelArgSVMPointer);
+ LoadFunction(clSetKernelExecInfo);
+ LoadFunction(clGetKernelInfo);
+ LoadFunction(clGetKernelArgInfo);
+ LoadFunction(clGetKernelWorkGroupInfo);
+ LoadFunction(clWaitForEvents);
+ LoadFunction(clGetEventInfo);
+ LoadFunction(clCreateUserEvent);
+ LoadFunction(clRetainEvent);
+ LoadFunction(clReleaseEvent);
+ LoadFunction(clSetUserEventStatus);
+ LoadFunction(clSetEventCallback);
+ LoadFunction(clGetEventProfilingInfo);
+ LoadFunction(clFlush);
+ LoadFunction(clFinish);
+ LoadFunction(clEnqueueReadBuffer);
+ LoadFunction(clEnqueueReadBufferRect);
+ LoadFunction(clEnqueueWriteBuffer);
+ LoadFunction(clEnqueueWriteBufferRect);
+ LoadFunction(clEnqueueFillBuffer);
+ LoadFunction(clEnqueueCopyBuffer);
+ LoadFunction(clEnqueueCopyBufferRect);
+ LoadFunction(clEnqueueReadImage);
+ LoadFunction(clEnqueueWriteImage);
+ LoadFunction(clEnqueueFillImage);
+ LoadFunction(clEnqueueCopyImage);
+ LoadFunction(clEnqueueCopyImageToBuffer);
+ LoadFunction(clEnqueueCopyBufferToImage);
+ LoadFunction(clEnqueueMapBuffer);
+ LoadFunction(clEnqueueMapImage);
+ LoadFunction(clEnqueueUnmapMemObject);
+ LoadFunction(clEnqueueMigrateMemObjects);
+ LoadFunction(clEnqueueNDRangeKernel);
+ LoadFunction(clEnqueueNativeKernel);
+ LoadFunction(clEnqueueMarkerWithWaitList);
+ LoadFunction(clEnqueueBarrierWithWaitList);
+ LoadFunction(clEnqueueSVMFree);
+ LoadFunction(clEnqueueSVMMemcpy);
+ LoadFunction(clEnqueueSVMMemFill);
+ LoadFunction(clEnqueueSVMMap);
+ LoadFunction(clEnqueueSVMUnmap);
+ LoadFunction(clGetExtensionFunctionAddressForPlatform);
+ LoadFunction(clCreateImage2D);
+ LoadFunction(clCreateImage3D);
+ LoadFunction(clEnqueueMarker);
+ LoadFunction(clEnqueueWaitForEvents);
+ LoadFunction(clEnqueueBarrier);
+ LoadFunction(clUnloadCompiler);
+ LoadFunction(clGetExtensionFunctionAddress);
+ LoadFunction(clCreateCommandQueue);
+ LoadFunction(clCreateSampler);
+ LoadFunction(clEnqueueTask);
+
+ // OpenGL sharing
+ LoadFunction(clCreateFromGLBuffer);
+ LoadFunction(clCreateFromGLTexture);
+ LoadFunction(clEnqueueAcquireGLObjects);
+ LoadFunction(clEnqueueReleaseGLObjects);
+
+ // cl_khr_egl_event extension
+ LoadFunction(clCreateEventFromEGLSyncKHR);
+
+ // EGL sharing
+ LoadFunction(clCreateFromEGLImageKHR);
+ LoadFunction(clEnqueueAcquireEGLObjectsKHR);
+ LoadFunction(clEnqueueReleaseEGLObjectsKHR);
+} // LoadOpenCLFunctions
+
+// Function pointer definitions; they stay null until LoadOpenCLFunctions assigns them.
+PFN_clGetPlatformIDs clGetPlatformIDs;
+PFN_clGetPlatformInfo clGetPlatformInfo;
+PFN_clGetDeviceIDs clGetDeviceIDs;
+PFN_clGetDeviceInfo clGetDeviceInfo;
+PFN_clCreateSubDevices clCreateSubDevices;
+PFN_clRetainDevice clRetainDevice;
+PFN_clReleaseDevice clReleaseDevice;
+PFN_clCreateContext clCreateContext;
+PFN_clCreateContextFromType clCreateContextFromType;
+PFN_clRetainContext clRetainContext;
+PFN_clReleaseContext clReleaseContext;
+PFN_clGetContextInfo clGetContextInfo;
+PFN_clCreateCommandQueueWithProperties clCreateCommandQueueWithProperties;
+PFN_clRetainCommandQueue clRetainCommandQueue;
+PFN_clReleaseCommandQueue clReleaseCommandQueue;
+PFN_clGetCommandQueueInfo clGetCommandQueueInfo;
+PFN_clCreateBuffer clCreateBuffer;
+PFN_clCreateSubBuffer clCreateSubBuffer;
+PFN_clCreateImage clCreateImage;
+PFN_clCreatePipe clCreatePipe;
+PFN_clRetainMemObject clRetainMemObject;
+PFN_clReleaseMemObject clReleaseMemObject;
+PFN_clGetSupportedImageFormats clGetSupportedImageFormats;
+PFN_clGetMemObjectInfo clGetMemObjectInfo;
+PFN_clGetImageInfo clGetImageInfo;
+PFN_clGetPipeInfo clGetPipeInfo;
+PFN_clSetMemObjectDestructorCallback clSetMemObjectDestructorCallback;
+PFN_clSVMAlloc clSVMAlloc;
+PFN_clSVMFree clSVMFree;
+PFN_clCreateSamplerWithProperties clCreateSamplerWithProperties;
+PFN_clRetainSampler clRetainSampler;
+PFN_clReleaseSampler clReleaseSampler;
+PFN_clGetSamplerInfo clGetSamplerInfo;
+PFN_clCreateProgramWithSource clCreateProgramWithSource;
+PFN_clCreateProgramWithBinary clCreateProgramWithBinary;
+PFN_clCreateProgramWithBuiltInKernels clCreateProgramWithBuiltInKernels;
+PFN_clRetainProgram clRetainProgram;
+PFN_clReleaseProgram clReleaseProgram;
+PFN_clBuildProgram clBuildProgram;
+PFN_clCompileProgram clCompileProgram;
+PFN_clLinkProgram clLinkProgram;
+PFN_clUnloadPlatformCompiler clUnloadPlatformCompiler;
+PFN_clGetProgramInfo clGetProgramInfo;
+PFN_clGetProgramBuildInfo clGetProgramBuildInfo;
+PFN_clCreateKernel clCreateKernel;
+PFN_clCreateKernelsInProgram clCreateKernelsInProgram;
+PFN_clRetainKernel clRetainKernel;
+PFN_clReleaseKernel clReleaseKernel;
+PFN_clSetKernelArg clSetKernelArg;
+PFN_clSetKernelArgSVMPointer clSetKernelArgSVMPointer;
+PFN_clSetKernelExecInfo clSetKernelExecInfo;
+PFN_clGetKernelInfo clGetKernelInfo;
+PFN_clGetKernelArgInfo clGetKernelArgInfo;
+PFN_clGetKernelWorkGroupInfo clGetKernelWorkGroupInfo;
+PFN_clWaitForEvents clWaitForEvents;
+PFN_clGetEventInfo clGetEventInfo;
+PFN_clCreateUserEvent clCreateUserEvent;
+PFN_clRetainEvent clRetainEvent;
+PFN_clReleaseEvent clReleaseEvent;
+PFN_clSetUserEventStatus clSetUserEventStatus;
+PFN_clSetEventCallback clSetEventCallback;
+PFN_clGetEventProfilingInfo clGetEventProfilingInfo;
+PFN_clFlush clFlush;
+PFN_clFinish clFinish;
+PFN_clEnqueueReadBuffer clEnqueueReadBuffer;
+PFN_clEnqueueReadBufferRect clEnqueueReadBufferRect;
+PFN_clEnqueueWriteBuffer clEnqueueWriteBuffer;
+PFN_clEnqueueWriteBufferRect clEnqueueWriteBufferRect;
+PFN_clEnqueueFillBuffer clEnqueueFillBuffer;
+PFN_clEnqueueCopyBuffer clEnqueueCopyBuffer;
+PFN_clEnqueueCopyBufferRect clEnqueueCopyBufferRect;
+PFN_clEnqueueReadImage clEnqueueReadImage;
+PFN_clEnqueueWriteImage clEnqueueWriteImage;
+PFN_clEnqueueFillImage clEnqueueFillImage;
+PFN_clEnqueueCopyImage clEnqueueCopyImage;
+PFN_clEnqueueCopyImageToBuffer clEnqueueCopyImageToBuffer;
+PFN_clEnqueueCopyBufferToImage clEnqueueCopyBufferToImage;
+PFN_clEnqueueMapBuffer clEnqueueMapBuffer;
+PFN_clEnqueueMapImage clEnqueueMapImage;
+PFN_clEnqueueUnmapMemObject clEnqueueUnmapMemObject;
+PFN_clEnqueueMigrateMemObjects clEnqueueMigrateMemObjects;
+PFN_clEnqueueNDRangeKernel clEnqueueNDRangeKernel;
+PFN_clEnqueueNativeKernel clEnqueueNativeKernel;
+PFN_clEnqueueMarkerWithWaitList clEnqueueMarkerWithWaitList;
+PFN_clEnqueueBarrierWithWaitList clEnqueueBarrierWithWaitList;
+PFN_clEnqueueSVMFree clEnqueueSVMFree;
+PFN_clEnqueueSVMMemcpy clEnqueueSVMMemcpy;
+PFN_clEnqueueSVMMemFill clEnqueueSVMMemFill;
+PFN_clEnqueueSVMMap clEnqueueSVMMap;
+PFN_clEnqueueSVMUnmap clEnqueueSVMUnmap;
+PFN_clGetExtensionFunctionAddressForPlatform clGetExtensionFunctionAddressForPlatform;
+PFN_clCreateImage2D clCreateImage2D;
+PFN_clCreateImage3D clCreateImage3D;
+PFN_clEnqueueMarker clEnqueueMarker;
+PFN_clEnqueueWaitForEvents clEnqueueWaitForEvents;
+PFN_clEnqueueBarrier clEnqueueBarrier;
+PFN_clUnloadCompiler clUnloadCompiler;
+PFN_clGetExtensionFunctionAddress clGetExtensionFunctionAddress;
+PFN_clCreateCommandQueue clCreateCommandQueue;
+PFN_clCreateSampler clCreateSampler;
+PFN_clEnqueueTask clEnqueueTask;
+
+// OpenGL sharing
+PFN_clCreateFromGLBuffer clCreateFromGLBuffer;
+PFN_clCreateFromGLTexture clCreateFromGLTexture;
+PFN_clEnqueueAcquireGLObjects clEnqueueAcquireGLObjects;
+PFN_clEnqueueReleaseGLObjects clEnqueueReleaseGLObjects;
+
+// cl_khr_egl_event extension
+PFN_clCreateEventFromEGLSyncKHR clCreateEventFromEGLSyncKHR;
+
+// EGL sharing
+PFN_clCreateFromEGLImageKHR clCreateFromEGLImageKHR;
+PFN_clEnqueueAcquireEGLObjectsKHR clEnqueueAcquireEGLObjectsKHR;
+PFN_clEnqueueReleaseEGLObjectsKHR clEnqueueReleaseEGLObjectsKHR;
+
+cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags,
+ const cl_image_format *image_format, const cl_image_desc *image_desc,
+ void *host_ptr, cl_int *errcode_ret)
+{
+ if (clCreateImage)
+ { // clCreateImage available since OpenCL 1.2
+ return clCreateImage(context, flags, image_format, image_desc, host_ptr, errcode_ret);
+ }
+ else
+ {
+ return clCreateImage2D(context, flags, image_format, image_desc->image_width,
+ image_desc->image_height, image_desc->image_row_pitch, host_ptr,
+ errcode_ret);
+ }
+}
+
+cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags,
+ const cl_image_format *image_format, const cl_image_desc *image_desc,
+ void *host_ptr, cl_int *errcode_ret)
+{
+ if (clCreateImage)
+ { // clCreateImage available since OpenCL 1.2
+ return clCreateImage(context, flags, image_format, image_desc, host_ptr, errcode_ret);
+ }
+ else
+ {
+ return clCreateImage3D(context, flags, image_format, image_desc->image_width,
+ image_desc->image_height, image_desc->image_depth,
+ image_desc->image_row_pitch, image_desc->image_slice_pitch, host_ptr,
+ errcode_ret);
+ }
+}
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_WRAPPER_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_WRAPPER_H__
+
+#include "CL/cl.h"
+#include "CL/cl_egl.h"
+#include "CL/cl_ext.h"
+#include "CL/cl_gl.h"
+#include "CL/cl_platform.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+absl::Status LoadOpenCL(void **libopencl);
+void UnloadOpenCL(void *libopencl);
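+
+// Typical usage sketch: load the library once, keep the handle for the process
+// lifetime, and release it at shutdown.
+//
+//   void *libopencl = nullptr;
+//   if (!LoadOpenCL(&libopencl).ok())
+//   {
+//     // no usable OpenCL implementation on this device
+//   }
+//   ...
+//   UnloadOpenCL(libopencl);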
+
+typedef cl_int(CL_API_CALL *PFN_clGetPlatformIDs)(
+ cl_uint /* num_entries */, cl_platform_id * /* platforms */,
+ cl_uint * /* num_platforms */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetPlatformInfo)(
+ cl_platform_id /* platform */, cl_platform_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetDeviceIDs)(
+ cl_platform_id /* platform */, cl_device_type /* device_type */, cl_uint /* num_entries */,
+ cl_device_id * /* devices */, cl_uint * /* num_devices */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetDeviceInfo)(
+ cl_device_id /* device */, cl_device_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clCreateSubDevices)(
+ cl_device_id /* in_device */, const cl_device_partition_property * /* properties */,
+ cl_uint /* num_devices */, cl_device_id * /* out_devices */,
+ cl_uint * /* num_devices_ret */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clRetainDevice)(cl_device_id /* device */)
+ CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clReleaseDevice)(cl_device_id /* device */)
+ CL_API_SUFFIX__VERSION_1_2;
+typedef cl_context(CL_API_CALL *PFN_clCreateContext)(
+ const cl_context_properties * /* properties */, cl_uint /* num_devices */,
+ const cl_device_id * /* devices */,
+ void(CL_CALLBACK * /* pfn_notify */)(const char *, const void *, size_t, void *),
+ void * /* user_data */, cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_context(CL_API_CALL *PFN_clCreateContextFromType)(
+ const cl_context_properties * /* properties */, cl_device_type /* device_type */,
+ void(CL_CALLBACK * /* pfn_notify*/)(const char *, const void *, size_t, void *),
+ void * /* user_data */, cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clRetainContext)(cl_context /* context */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clReleaseContext)(cl_context /* context */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetContextInfo)(
+ cl_context /* context */, cl_context_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_command_queue(CL_API_CALL *PFN_clCreateCommandQueueWithProperties)(
+ cl_context /* context */, cl_device_id /* device */, const cl_queue_properties * /* properties */,
+ cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clRetainCommandQueue)(cl_command_queue /* command_queue */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clReleaseCommandQueue)(cl_command_queue /* command_queue */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetCommandQueueInfo)(
+ cl_command_queue /* command_queue */, cl_command_queue_info /* param_name */,
+ size_t /* param_value_size */, void * /* param_value */,
+ size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_mem(CL_API_CALL *PFN_clCreateBuffer)(
+ cl_context /* context */, cl_mem_flags /* flags */, size_t /* size */, void * /* host_ptr */,
+ cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_mem(CL_API_CALL *PFN_clCreateSubBuffer)(
+ cl_mem /* buffer */, cl_mem_flags /* flags */, cl_buffer_create_type /* buffer_create_type */,
+ const void * /* buffer_create_info */, cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_1;
+typedef cl_mem(CL_API_CALL *PFN_clCreateImage)(
+ cl_context /* context */, cl_mem_flags /* flags */, const cl_image_format * /* image_format */,
+ const cl_image_desc * /* image_desc */, void * /* host_ptr */,
+ cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_mem(CL_API_CALL *PFN_clCreatePipe)(
+ cl_context /* context */, cl_mem_flags /* flags */, cl_uint /* pipe_packet_size */,
+ cl_uint /* pipe_max_packets */, const cl_pipe_properties * /* properties */,
+ cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clRetainMemObject)(cl_mem /* memobj */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clReleaseMemObject)(cl_mem /* memobj */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetSupportedImageFormats)(
+ cl_context /* context */, cl_mem_flags /* flags */, cl_mem_object_type /* image_type */,
+ cl_uint /* num_entries */, cl_image_format * /* image_formats */,
+ cl_uint * /* num_image_formats */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetMemObjectInfo)(
+ cl_mem /* memobj */, cl_mem_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetImageInfo)(
+ cl_mem /* image */, cl_image_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetPipeInfo)(
+ cl_mem /* pipe */, cl_pipe_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clSetMemObjectDestructorCallback)(
+ cl_mem /* memobj */,
+ void(CL_CALLBACK * /*pfn_notify*/)(cl_mem /* memobj */, void * /*user_data*/),
+ void * /*user_data */) CL_API_SUFFIX__VERSION_1_1;
+typedef void *(CL_API_CALL *PFN_clSVMAlloc)(cl_context /* context */, cl_svm_mem_flags /* flags */,
+ size_t /* size */,
+ cl_uint /* alignment */)CL_API_SUFFIX__VERSION_2_0;
+typedef void(CL_API_CALL *PFN_clSVMFree)(cl_context /* context */,
+ void * /* svm_pointer */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_sampler(CL_API_CALL *PFN_clCreateSamplerWithProperties)(
+ cl_context /* context */, const cl_sampler_properties * /* sampler_properties */,
+ cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clRetainSampler)(cl_sampler /* sampler */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clReleaseSampler)(cl_sampler /* sampler */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetSamplerInfo)(
+ cl_sampler /* sampler */, cl_sampler_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithSource)(
+ cl_context /* context */, cl_uint /* count */, const char ** /* strings */,
+ const size_t * /* lengths */, cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithBinary)(
+ cl_context /* context */, cl_uint /* num_devices */, const cl_device_id * /* device_list */,
+ const size_t * /* lengths */, const unsigned char ** /* binaries */, cl_int * /* binary_status */,
+ cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_program(CL_API_CALL *PFN_clCreateProgramWithBuiltInKernels)(
+ cl_context /* context */, cl_uint /* num_devices */, const cl_device_id * /* device_list */,
+ const char * /* kernel_names */, cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clRetainProgram)(cl_program /* program */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clReleaseProgram)(cl_program /* program */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clBuildProgram)(
+ cl_program /* program */, cl_uint /* num_devices */, const cl_device_id * /* device_list */,
+ const char * /* options */,
+ void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, void * /* user_data */),
+ void * /* user_data */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clCompileProgram)(
+ cl_program /* program */, cl_uint /* num_devices */, const cl_device_id * /* device_list */,
+ const char * /* options */, cl_uint /* num_input_headers */,
+ const cl_program * /* input_headers */, const char ** /* header_include_names */,
+ void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, void * /* user_data */),
+ void * /* user_data */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_program(CL_API_CALL *PFN_clLinkProgram)(
+ cl_context /* context */, cl_uint /* num_devices */, const cl_device_id * /* device_list */,
+ const char * /* options */, cl_uint /* num_input_programs */,
+ const cl_program * /* input_programs */,
+ void(CL_CALLBACK * /* pfn_notify */)(cl_program /* program */, void * /* user_data */),
+ void * /* user_data */, cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clUnloadPlatformCompiler)(cl_platform_id /* platform */)
+ CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clGetProgramInfo)(
+ cl_program /* program */, cl_program_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetProgramBuildInfo)(
+ cl_program /* program */, cl_device_id /* device */, cl_program_build_info /* param_name */,
+ size_t /* param_value_size */, void * /* param_value */,
+ size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_kernel(CL_API_CALL *PFN_clCreateKernel)(
+ cl_program /* program */, const char * /* kernel_name */,
+ cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clCreateKernelsInProgram)(
+ cl_program /* program */, cl_uint /* num_kernels */, cl_kernel * /* kernels */,
+ cl_uint * /* num_kernels_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clRetainKernel)(cl_kernel /* kernel */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clReleaseKernel)(cl_kernel /* kernel */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clSetKernelArg)(cl_kernel /* kernel */, cl_uint /* arg_index */,
+ size_t /* arg_size */, const void * /* arg_value */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clSetKernelArgSVMPointer)(
+ cl_kernel /* kernel */, cl_uint /* arg_index */,
+ const void * /* arg_value */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clSetKernelExecInfo)(
+ cl_kernel /* kernel */, cl_kernel_exec_info /* param_name */, size_t /* param_value_size */,
+ const void * /* param_value */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clGetKernelInfo)(
+ cl_kernel /* kernel */, cl_kernel_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetKernelArgInfo)(
+ cl_kernel /* kernel */, cl_uint /* arg_indx */, cl_kernel_arg_info /* param_name */,
+ size_t /* param_value_size */, void * /* param_value */,
+ size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clGetKernelWorkGroupInfo)(
+ cl_kernel /* kernel */, cl_device_id /* device */, cl_kernel_work_group_info /* param_name */,
+ size_t /* param_value_size */, void * /* param_value */,
+ size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clWaitForEvents)(
+ cl_uint /* num_events */, const cl_event * /* event_list */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clGetEventInfo)(
+ cl_event /* event */, cl_event_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_event(CL_API_CALL *PFN_clCreateUserEvent)(
+ cl_context /* context */, cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_1;
+typedef cl_int(CL_API_CALL *PFN_clRetainEvent)(cl_event /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clReleaseEvent)(cl_event /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clSetUserEventStatus)(
+ cl_event /* event */, cl_int /* execution_status */) CL_API_SUFFIX__VERSION_1_1;
+typedef cl_int(CL_API_CALL *PFN_clSetEventCallback)(
+ cl_event /* event */, cl_int /* command_exec_callback_type */,
+ void(CL_CALLBACK * /* pfn_notify */)(cl_event, cl_int, void *),
+ void * /* user_data */) CL_API_SUFFIX__VERSION_1_1;
+typedef cl_int(CL_API_CALL *PFN_clGetEventProfilingInfo)(
+ cl_event /* event */, cl_profiling_info /* param_name */, size_t /* param_value_size */,
+ void * /* param_value */, size_t * /* param_value_size_ret */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clFlush)(cl_command_queue /* command_queue */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clFinish)(cl_command_queue /* command_queue */)
+ CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueReadBuffer)(
+ cl_command_queue /* command_queue */, cl_mem /* buffer */, cl_bool /* blocking_read */,
+ size_t /* offset */, size_t /* size */, void * /* ptr */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueReadBufferRect)(
+ cl_command_queue /* command_queue */, cl_mem /* buffer */, cl_bool /* blocking_read */,
+ const size_t * /* buffer_offset */, const size_t * /* host_offset */, const size_t * /* region */,
+ size_t /* buffer_row_pitch */, size_t /* buffer_slice_pitch */, size_t /* host_row_pitch */,
+ size_t /* host_slice_pitch */, void * /* ptr */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteBuffer)(
+ cl_command_queue /* command_queue */, cl_mem /* buffer */, cl_bool /* blocking_write */,
+ size_t /* offset */, size_t /* size */, const void * /* ptr */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteBufferRect)(
+ cl_command_queue /* command_queue */, cl_mem /* buffer */, cl_bool /* blocking_write */,
+ const size_t * /* buffer_offset */, const size_t * /* host_offset */, const size_t * /* region */,
+ size_t /* buffer_row_pitch */, size_t /* buffer_slice_pitch */, size_t /* host_row_pitch */,
+ size_t /* host_slice_pitch */, const void * /* ptr */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueFillBuffer)(
+ cl_command_queue /* command_queue */, cl_mem /* buffer */, const void * /* pattern */,
+ size_t /* pattern_size */, size_t /* offset */, size_t /* size */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBuffer)(
+ cl_command_queue /* command_queue */, cl_mem /* src_buffer */, cl_mem /* dst_buffer */,
+ size_t /* src_offset */, size_t /* dst_offset */, size_t /* size */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBufferRect)(
+ cl_command_queue /* command_queue */, cl_mem /* src_buffer */, cl_mem /* dst_buffer */,
+ const size_t * /* src_origin */, const size_t * /* dst_origin */, const size_t * /* region */,
+ size_t /* src_row_pitch */, size_t /* src_slice_pitch */, size_t /* dst_row_pitch */,
+ size_t /* dst_slice_pitch */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_1;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueReadImage)(
+ cl_command_queue /* command_queue */, cl_mem /* image */, cl_bool /* blocking_read */,
+ const size_t * /* origin[3] */, const size_t * /* region[3] */, size_t /* row_pitch */,
+ size_t /* slice_pitch */, void * /* ptr */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueWriteImage)(
+ cl_command_queue /* command_queue */, cl_mem /* image */, cl_bool /* blocking_write */,
+ const size_t * /* origin[3] */, const size_t * /* region[3] */, size_t /* input_row_pitch */,
+ size_t /* input_slice_pitch */, const void * /* ptr */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueFillImage)(
+ cl_command_queue /* command_queue */, cl_mem /* image */, const void * /* fill_color */,
+ const size_t * /* origin[3] */, const size_t * /* region[3] */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyImage)(
+ cl_command_queue /* command_queue */, cl_mem /* src_image */, cl_mem /* dst_image */,
+ const size_t * /* src_origin[3] */, const size_t * /* dst_origin[3] */,
+ const size_t * /* region[3] */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyImageToBuffer)(
+ cl_command_queue /* command_queue */, cl_mem /* src_image */, cl_mem /* dst_buffer */,
+ const size_t * /* src_origin[3] */, const size_t * /* region[3] */, size_t /* dst_offset */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueCopyBufferToImage)(
+ cl_command_queue /* command_queue */, cl_mem /* src_buffer */, cl_mem /* dst_image */,
+ size_t /* src_offset */, const size_t * /* dst_origin[3] */, const size_t * /* region[3] */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef void *(CL_API_CALL *PFN_clEnqueueMapBuffer)(
+ cl_command_queue /* command_queue */, cl_mem /* buffer */, cl_bool /* blocking_map */,
+ cl_map_flags /* map_flags */, size_t /* offset */, size_t /* size */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */, cl_int * /* errcode_ret */)CL_API_SUFFIX__VERSION_1_0;
+typedef void *(CL_API_CALL *PFN_clEnqueueMapImage)(
+ cl_command_queue /* command_queue */, cl_mem /* image */, cl_bool /* blocking_map */,
+ cl_map_flags /* map_flags */, const size_t * /* origin[3] */, const size_t * /* region[3] */,
+ size_t * /* image_row_pitch */, size_t * /* image_slice_pitch */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */, cl_int * /* errcode_ret */)CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueUnmapMemObject)(
+ cl_command_queue /* command_queue */, cl_mem /* memobj */, void * /* mapped_ptr */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueMigrateMemObjects)(
+ cl_command_queue /* command_queue */, cl_uint /* num_mem_objects */,
+ const cl_mem * /* mem_objects */, cl_mem_migration_flags /* flags */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueNDRangeKernel)(
+ cl_command_queue /* command_queue */, cl_kernel /* kernel */, cl_uint /* work_dim */,
+ const size_t * /* global_work_offset */, const size_t * /* global_work_size */,
+ const size_t * /* local_work_size */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueNativeKernel)(
+ cl_command_queue /* command_queue */, void(CL_CALLBACK * /*user_func*/)(void *),
+ void * /* args */, size_t /* cb_args */, cl_uint /* num_mem_objects */,
+ const cl_mem * /* mem_list */, const void ** /* args_mem_loc */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueMarkerWithWaitList)(
+ cl_command_queue /* command_queue */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueBarrierWithWaitList)(
+ cl_command_queue /* command_queue */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMFree)(
+ cl_command_queue /* command_queue */, cl_uint /* num_svm_pointers */,
+ void *[] /* svm_pointers[] */,
+ void(CL_CALLBACK * /*pfn_free_func*/)(cl_command_queue /* queue */,
+ cl_uint /* num_svm_pointers */,
+ void *[] /* svm_pointers[] */, void * /* user_data */),
+ void * /* user_data */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMemcpy)(
+ cl_command_queue /* command_queue */, cl_bool /* blocking_copy */, void * /* dst_ptr */,
+ const void * /* src_ptr */, size_t /* size */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMemFill)(
+ cl_command_queue /* command_queue */, void * /* svm_ptr */, const void * /* pattern */,
+ size_t /* pattern_size */, size_t /* size */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMMap)(
+ cl_command_queue /* command_queue */, cl_bool /* blocking_map */, cl_map_flags /* flags */,
+ void * /* svm_ptr */, size_t /* size */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueSVMUnmap)(
+ cl_command_queue /* command_queue */, void * /* svm_ptr */, cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */, cl_event * /* event */) CL_API_SUFFIX__VERSION_2_0;
+typedef void *(CL_API_CALL *PFN_clGetExtensionFunctionAddressForPlatform)(
+ cl_platform_id /* platform */, const char * /* func_name */)CL_API_SUFFIX__VERSION_1_2;
+typedef cl_mem(CL_API_CALL *PFN_clCreateImage2D)(cl_context /* context */, cl_mem_flags /* flags */,
+ const cl_image_format * /* image_format */,
+ size_t /* image_width */,
+ size_t /* image_height */,
+ size_t /* image_row_pitch */,
+ void * /* host_ptr */, cl_int * /* errcode_ret */);
+typedef cl_mem(CL_API_CALL *PFN_clCreateImage3D)(
+ cl_context /* context */, cl_mem_flags /* flags */, const cl_image_format * /* image_format */,
+ size_t /* image_width */, size_t /* image_height */, size_t /* image_depth */,
+ size_t /* image_row_pitch */, size_t /* image_slice_pitch */, void * /* host_ptr */,
+ cl_int * /* errcode_ret */);
+typedef cl_int(CL_API_CALL *PFN_clEnqueueMarker)(cl_command_queue /* command_queue */,
+ cl_event * /* event */);
+typedef cl_int(CL_API_CALL *PFN_clEnqueueWaitForEvents)(cl_command_queue /* command_queue */,
+ cl_uint /* num_events */,
+ const cl_event * /* event_list */);
+typedef cl_int(CL_API_CALL *PFN_clEnqueueBarrier)(cl_command_queue /* command_queue */);
+typedef cl_int(CL_API_CALL *PFN_clUnloadCompiler)();
+typedef void *(CL_API_CALL *PFN_clGetExtensionFunctionAddress)(const char * /* func_name */);
+typedef cl_command_queue(CL_API_CALL *PFN_clCreateCommandQueue)(
+ cl_context /* context */, cl_device_id /* device */, cl_command_queue_properties /* properties */,
+ cl_int * /* errcode_ret */);
+typedef cl_sampler(CL_API_CALL *PFN_clCreateSampler)(cl_context /* context */,
+ cl_bool /* normalized_coords */,
+ cl_addressing_mode /* addressing_mode */,
+ cl_filter_mode /* filter_mode */,
+ cl_int * /* errcode_ret */);
+typedef cl_int(CL_API_CALL *PFN_clEnqueueTask)(cl_command_queue /* command_queue */,
+ cl_kernel /* kernel */,
+ cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */,
+ cl_event * /* event */);
+
+// OpenGL sharing
+typedef cl_mem(CL_API_CALL *PFN_clCreateFromGLBuffer)(cl_context, cl_mem_flags, cl_GLuint,
+ cl_int *);
+typedef cl_mem(CL_API_CALL *PFN_clCreateFromGLTexture)(
+ cl_context /* context */, cl_mem_flags /* flags */, cl_GLenum /* target */,
+ cl_GLint /* miplevel */, cl_GLuint /* texture */,
+ cl_int * /* errcode_ret */) CL_API_SUFFIX__VERSION_1_2;
+typedef cl_int(CL_API_CALL *PFN_clEnqueueAcquireGLObjects)(cl_command_queue /* command_queue */,
+ cl_uint /* num_objects */,
+ const cl_mem * /* mem_objects */,
+ cl_uint /* num_events_in_wait_list */,
+ const cl_event * /* event_wait_list */,
+ cl_event * /* event */);
+typedef cl_int(CL_API_CALL *PFN_clEnqueueReleaseGLObjects)(
+ cl_command_queue /* command_queue */, cl_uint /* num_objects */, const cl_mem * /* mem_objects */,
+ cl_uint /* num_events_in_wait_list */, const cl_event * /* event_wait_list */,
+ cl_event * /* event */) CL_API_SUFFIX__VERSION_1_0;
+
+// cl_khr_egl_event extension
+
+// CLeglDisplayKHR is an opaque handle to an EGLDisplay
+typedef void *CLeglDisplayKHR;
+
+// CLeglSyncKHR is an opaque handle to an EGLSync object
+typedef void *CLeglSyncKHR;
+
+typedef cl_event(CL_API_CALL *PFN_clCreateEventFromEGLSyncKHR)(cl_context /* context */,
+ CLeglSyncKHR /* sync */,
+ CLeglDisplayKHR /* display */,
+ cl_int * /* errcode_ret */);
+
+// EGL sharing
+typedef cl_mem(CL_API_CALL *PFN_clCreateFromEGLImageKHR)(
+ cl_context /*context*/, CLeglDisplayKHR /*display*/, CLeglImageKHR /*image*/,
+ cl_mem_flags /*flags*/, const cl_egl_image_properties_khr * /*properties*/,
+ cl_int * /*errcode_ret*/);
+typedef cl_int(CL_API_CALL *PFN_clEnqueueAcquireEGLObjectsKHR)(
+ cl_command_queue /*command_queue*/, cl_uint /*num_objects*/, const cl_mem * /*mem_objects*/,
+ cl_uint /*num_events_in_wait_list*/, const cl_event * /*event_wait_list*/, cl_event * /*event*/);
+typedef cl_int(CL_API_CALL *PFN_clEnqueueReleaseEGLObjectsKHR)(
+ cl_command_queue /*command_queue*/, cl_uint /*num_objects*/, const cl_mem * /*mem_objects*/,
+ cl_uint /*num_events_in_wait_list*/, const cl_event * /*event_wait_list*/, cl_event * /*event*/);
+
+extern PFN_clGetPlatformIDs clGetPlatformIDs;
+extern PFN_clGetPlatformInfo clGetPlatformInfo;
+extern PFN_clGetDeviceIDs clGetDeviceIDs;
+extern PFN_clGetDeviceInfo clGetDeviceInfo;
+extern PFN_clCreateSubDevices clCreateSubDevices;
+extern PFN_clRetainDevice clRetainDevice;
+extern PFN_clReleaseDevice clReleaseDevice;
+extern PFN_clCreateContext clCreateContext;
+extern PFN_clCreateContextFromType clCreateContextFromType;
+extern PFN_clRetainContext clRetainContext;
+extern PFN_clReleaseContext clReleaseContext;
+extern PFN_clGetContextInfo clGetContextInfo;
+extern PFN_clCreateCommandQueueWithProperties clCreateCommandQueueWithProperties;
+extern PFN_clRetainCommandQueue clRetainCommandQueue;
+extern PFN_clReleaseCommandQueue clReleaseCommandQueue;
+extern PFN_clGetCommandQueueInfo clGetCommandQueueInfo;
+extern PFN_clCreateBuffer clCreateBuffer;
+extern PFN_clCreateSubBuffer clCreateSubBuffer;
+extern PFN_clCreateImage clCreateImage;
+extern PFN_clCreatePipe clCreatePipe;
+extern PFN_clRetainMemObject clRetainMemObject;
+extern PFN_clReleaseMemObject clReleaseMemObject;
+extern PFN_clGetSupportedImageFormats clGetSupportedImageFormats;
+extern PFN_clGetMemObjectInfo clGetMemObjectInfo;
+extern PFN_clGetImageInfo clGetImageInfo;
+extern PFN_clGetPipeInfo clGetPipeInfo;
+extern PFN_clSetMemObjectDestructorCallback clSetMemObjectDestructorCallback;
+extern PFN_clSVMAlloc clSVMAlloc;
+extern PFN_clSVMFree clSVMFree;
+extern PFN_clCreateSamplerWithProperties clCreateSamplerWithProperties;
+extern PFN_clRetainSampler clRetainSampler;
+extern PFN_clReleaseSampler clReleaseSampler;
+extern PFN_clGetSamplerInfo clGetSamplerInfo;
+extern PFN_clCreateProgramWithSource clCreateProgramWithSource;
+extern PFN_clCreateProgramWithBinary clCreateProgramWithBinary;
+extern PFN_clCreateProgramWithBuiltInKernels clCreateProgramWithBuiltInKernels;
+extern PFN_clRetainProgram clRetainProgram;
+extern PFN_clReleaseProgram clReleaseProgram;
+extern PFN_clBuildProgram clBuildProgram;
+extern PFN_clCompileProgram clCompileProgram;
+extern PFN_clLinkProgram clLinkProgram;
+extern PFN_clUnloadPlatformCompiler clUnloadPlatformCompiler;
+extern PFN_clGetProgramInfo clGetProgramInfo;
+extern PFN_clGetProgramBuildInfo clGetProgramBuildInfo;
+extern PFN_clCreateKernel clCreateKernel;
+extern PFN_clCreateKernelsInProgram clCreateKernelsInProgram;
+extern PFN_clRetainKernel clRetainKernel;
+extern PFN_clReleaseKernel clReleaseKernel;
+extern PFN_clSetKernelArg clSetKernelArg;
+extern PFN_clSetKernelArgSVMPointer clSetKernelArgSVMPointer;
+extern PFN_clSetKernelExecInfo clSetKernelExecInfo;
+extern PFN_clGetKernelInfo clGetKernelInfo;
+extern PFN_clGetKernelArgInfo clGetKernelArgInfo;
+extern PFN_clGetKernelWorkGroupInfo clGetKernelWorkGroupInfo;
+extern PFN_clWaitForEvents clWaitForEvents;
+extern PFN_clGetEventInfo clGetEventInfo;
+extern PFN_clCreateUserEvent clCreateUserEvent;
+extern PFN_clRetainEvent clRetainEvent;
+extern PFN_clReleaseEvent clReleaseEvent;
+extern PFN_clSetUserEventStatus clSetUserEventStatus;
+extern PFN_clSetEventCallback clSetEventCallback;
+extern PFN_clGetEventProfilingInfo clGetEventProfilingInfo;
+extern PFN_clFlush clFlush;
+extern PFN_clFinish clFinish;
+extern PFN_clEnqueueReadBuffer clEnqueueReadBuffer;
+extern PFN_clEnqueueReadBufferRect clEnqueueReadBufferRect;
+extern PFN_clEnqueueWriteBuffer clEnqueueWriteBuffer;
+extern PFN_clEnqueueWriteBufferRect clEnqueueWriteBufferRect;
+extern PFN_clEnqueueFillBuffer clEnqueueFillBuffer;
+extern PFN_clEnqueueCopyBuffer clEnqueueCopyBuffer;
+extern PFN_clEnqueueCopyBufferRect clEnqueueCopyBufferRect;
+extern PFN_clEnqueueReadImage clEnqueueReadImage;
+extern PFN_clEnqueueWriteImage clEnqueueWriteImage;
+extern PFN_clEnqueueFillImage clEnqueueFillImage;
+extern PFN_clEnqueueCopyImage clEnqueueCopyImage;
+extern PFN_clEnqueueCopyImageToBuffer clEnqueueCopyImageToBuffer;
+extern PFN_clEnqueueCopyBufferToImage clEnqueueCopyBufferToImage;
+extern PFN_clEnqueueMapBuffer clEnqueueMapBuffer;
+extern PFN_clEnqueueMapImage clEnqueueMapImage;
+extern PFN_clEnqueueUnmapMemObject clEnqueueUnmapMemObject;
+extern PFN_clEnqueueMigrateMemObjects clEnqueueMigrateMemObjects;
+extern PFN_clEnqueueNDRangeKernel clEnqueueNDRangeKernel;
+extern PFN_clEnqueueNativeKernel clEnqueueNativeKernel;
+extern PFN_clEnqueueMarkerWithWaitList clEnqueueMarkerWithWaitList;
+extern PFN_clEnqueueBarrierWithWaitList clEnqueueBarrierWithWaitList;
+extern PFN_clEnqueueSVMFree clEnqueueSVMFree;
+extern PFN_clEnqueueSVMMemcpy clEnqueueSVMMemcpy;
+extern PFN_clEnqueueSVMMemFill clEnqueueSVMMemFill;
+extern PFN_clEnqueueSVMMap clEnqueueSVMMap;
+extern PFN_clEnqueueSVMUnmap clEnqueueSVMUnmap;
+extern PFN_clGetExtensionFunctionAddressForPlatform clGetExtensionFunctionAddressForPlatform;
+extern PFN_clCreateImage2D clCreateImage2D;
+extern PFN_clCreateImage3D clCreateImage3D;
+extern PFN_clEnqueueMarker clEnqueueMarker;
+extern PFN_clEnqueueWaitForEvents clEnqueueWaitForEvents;
+extern PFN_clEnqueueBarrier clEnqueueBarrier;
+extern PFN_clUnloadCompiler clUnloadCompiler;
+extern PFN_clGetExtensionFunctionAddress clGetExtensionFunctionAddress;
+extern PFN_clCreateCommandQueue clCreateCommandQueue;
+extern PFN_clCreateSampler clCreateSampler;
+extern PFN_clEnqueueTask clEnqueueTask;
+
+// OpenGL sharing
+extern PFN_clCreateFromGLBuffer clCreateFromGLBuffer;
+extern PFN_clCreateFromGLTexture clCreateFromGLTexture;
+extern PFN_clEnqueueAcquireGLObjects clEnqueueAcquireGLObjects;
+extern PFN_clEnqueueReleaseGLObjects clEnqueueReleaseGLObjects;
+
+// cl_khr_egl_event extension
+extern PFN_clCreateEventFromEGLSyncKHR clCreateEventFromEGLSyncKHR;
+
+// EGL sharing
+extern PFN_clCreateFromEGLImageKHR clCreateFromEGLImageKHR;
+extern PFN_clEnqueueAcquireEGLObjectsKHR clEnqueueAcquireEGLObjectsKHR;
+extern PFN_clEnqueueReleaseEGLObjectsKHR clEnqueueReleaseEGLObjectsKHR;
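+// Note: the extern function pointers above are expected to be resolved at runtime (for example
+// via dlopen()/dlsym() of the vendor OpenCL library) by loader code elsewhere in this backend;
+// none of them may be called before that loading step succeeds. (Informational note; the loader
+// itself is not shown in this excerpt.)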
+
+// For convenient image creation.
+// It uses clCreateImage if it is available (clCreateImage has been available since OpenCL 1.2);
+// otherwise it falls back to the legacy clCreateImage2D.
+cl_mem CreateImage2DLegacy(cl_context context, cl_mem_flags flags,
+ const cl_image_format *image_format, const cl_image_desc *image_desc,
+ void *host_ptr, cl_int *errcode_ret);
+
+// It uses clCreateImage if it is available (clCreateImage has been available since OpenCL 1.2);
+// otherwise it falls back to the legacy clCreateImage3D.
+cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags,
+ const cl_image_format *image_format, const cl_image_desc *image_desc,
+ void *host_ptr, cl_int *errcode_ret);
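+
+// Illustrative usage of the helpers above (hypothetical values; assumes the OpenCL entry points
+// have already been loaded and that `context` is a valid cl_context):
+//
+//   cl_image_format format{CL_RGBA, CL_FLOAT};
+//   cl_image_desc desc{};
+//   desc.image_type = CL_MEM_OBJECT_TYPE_IMAGE2D;
+//   desc.image_width = 64;
+//   desc.image_height = 64;
+//   cl_int error = CL_SUCCESS;
+//   cl_mem image = CreateImage2DLegacy(context, CL_MEM_READ_WRITE, &format, &desc,
+//                                      /*host_ptr=*/nullptr, &error);
+//   // On pre-1.2 platforms this transparently falls back to clCreateImage2D.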
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_WRAPPERE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Operations.h"
+#include "open_cl/Operations.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+#include <unordered_map>
+
+#include "absl/container/flat_hash_map.h"
+
+#include "Shape.h"
+#include "Status.h"
+#include "InternalTensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+Padding2D &Padding2D::operator=(const Padding2D &value)
+{
+ prepended = value.prepended;
+ appended = value.appended;
+ return *this;
+}
+
+bool Padding2D::operator==(const Padding2D &value)
+{
+ return this->prepended == value.prepended && this->appended == value.appended;
+}
+
+bool Padding2D::operator!=(const Padding2D &value) { return !(*this == value); }
+
+Padding2D &Padding2D::operator-(const Padding2D &value)
+{
+ prepended.h -= value.prepended.h;
+ prepended.w -= value.prepended.w;
+ appended.h -= value.appended.h;
+ appended.w -= value.appended.w;
+ return *this;
+}
+
+Padding3D &Padding3D::operator=(const Padding3D &value)
+{
+ prepended = value.prepended;
+ appended = value.appended;
+ return *this;
+}
+
+bool Padding3D::operator==(const Padding3D &value)
+{
+ return this->prepended == value.prepended && this->appended == value.appended;
+}
+
+bool Padding3D::operator!=(const Padding3D &value) { return !(*this == value); }
+
+Padding3D &Padding3D::operator-(const Padding3D &value)
+{
+ prepended.h -= value.prepended.h;
+ prepended.w -= value.prepended.w;
+ prepended.d -= value.prepended.d;
+ appended.h -= value.appended.h;
+ appended.w -= value.appended.w;
+ appended.d -= value.appended.d;
+ return *this;
+}
+
+std::string ToString(enum OperationType op)
+{
+ switch (op)
+ {
+ // case OperationType::ABS:
+ // return "abs";
+ case OperationType::ADD:
+ return "add";
+ // case OperationType::CONCAT:
+ // return "concat";
+ // case OperationType::COS:
+ // return "cos";
+ // case OperationType::EXP:
+ // return "exp";
+ // case OperationType::LOG:
+ // return "log";
+ // case OperationType::NEG:
+ // return "neg";
+ // case OperationType::POOLING_2D:
+ // return "pooling_2d";
+ // case OperationType::REDUCE_MAXIMUM:
+ // return "reduce_maximum";
+ // case OperationType::REDUCE_MINIMUM:
+ // return "reduce_minimum";
+ // case OperationType::REDUCE_PRODUCT:
+ // return "reduce_product";
+ // case OperationType::REDUCE_SUM:
+ // return "reduce_sum";
+ // case OperationType::RESIZE:
+ // return "resize";
+ // case OperationType::RELU:
+ // return "relu";
+ // case OperationType::RSQRT:
+ // return "rsqrt";
+ // case OperationType::SQRT:
+ // return "sqrt";
+ // case OperationType::SQUARE:
+ // return "square";
+ case OperationType::UNKNOWN:
+ return "unknown_operation";
+ }
+ return "";
+}
+
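+// Reverse mapping of ToString(); names that are not recognized map to OperationType::UNKNOWN,
+// e.g. OperationTypeFromString("add") == OperationType::ADD and
+// OperationTypeFromString("no_such_op") == OperationType::UNKNOWN.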
+OperationType OperationTypeFromString(const std::string &name)
+{
+ static const auto operations = new std::unordered_map<std::string, OperationType>({
+ // {"abs", OperationType::ABS},
+ {"add", OperationType::ADD},
+ // {"concat", OperationType::CONCAT},
+ // {"cos", OperationType::COS},
+ // {"exp", OperationType::EXP},
+ // {"log", OperationType::LOG},
+ // {"neg", OperationType::NEG},
+ // {"pooling_2d", OperationType::POOLING_2D},
+ // {"reduce_maximum", OperationType::REDUCE_MAXIMUM},
+ // {"reduce_minimum", OperationType::REDUCE_MINIMUM},
+ // {"reduce_product", OperationType::REDUCE_PRODUCT},
+ // {"reduce_sum", OperationType::REDUCE_SUM},
+ // {"relu", OperationType::RELU},
+ // {"resize", OperationType::RESIZE},
+ // {"rsqrt", OperationType::RSQRT},
+ // {"sqrt", OperationType::SQRT},
+ // {"square", OperationType::SQUARE},
+ });
+ auto op = operations->find(name);
+ return op == operations->end() ? OperationType::UNKNOWN : op->second;
+}
+
+namespace
+{
+
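+// Integer ceiling division for positive operands, e.g. DivideRoundUp(7, 3) == 3 and
+// DivideRoundUp(6, 3) == 2. Assumes n >= 1 and divisor >= 1.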
+template <typename T> T DivideRoundUp(T n, T divisor) { return (n - 1) / divisor + 1; }
+
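+// Output extent of a (possibly dilated) kernel sliding over `input` with the given total
+// padding, before strides are applied. Worked example: input = 5, kernel = 3, padding = 2,
+// dilation = 1 gives dilated_kernel = 3 and a result of 5 + 2 - 3 + 1 = 5.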
+int32_t CalculateOutputSizeBeforeStrides(int32_t input, int32_t kernel, int32_t padding,
+ int32_t dilation)
+{
+ const int32_t dilated_kernel = (kernel - 1) * dilation + 1;
+ return input + padding - dilated_kernel + 1;
+}
+
+template <Axis T>
+int32_t CalculateOutputWithoutStrides(const BHWC &input, const Convolution2DAttributes &attr)
+{
+ return CalculateOutputSizeBeforeStrides(
+ input.get<T>(), attr.weights.shape.get<T>(),
+ attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(), attr.dilations.get<T>());
+}
+
+template <Axis T>
+int32_t CalculateOutputWithoutStrides(const BHWDC &input, const Convolution3DAttributes &attr)
+{
+ return CalculateOutputSizeBeforeStrides(
+ input.get<T>(), attr.weights.shape.get<T>(),
+ attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(), attr.dilations.get<T>());
+}
+
+template <Axis T>
+int32_t CalculateOutputWithoutStrides(const BHWC &input, const Pooling2DAttributes &attr)
+{
+ return CalculateOutputSizeBeforeStrides(input.get<T>(), attr.kernel.get<T>(),
+ attr.padding.prepended.get<T>() +
+ attr.padding.appended.get<T>(),
+ /*dilation=*/1);
+}
+
+template <Axis T>
+int32_t CalculateOutputWithoutStrides(const BHWDC &input, const Pooling3DAttributes &attr)
+{
+ return CalculateOutputSizeBeforeStrides(input.get<T>(), attr.kernel.get<T>(),
+ attr.padding.prepended.get<T>() +
+ attr.padding.appended.get<T>(),
+ /*dilation=*/1);
+}
+
+template <Axis T>
+int32_t CalculateOutput(const BHWC &input, const ConvolutionTransposedAttributes &attr)
+{
+ return (input.get<T>() - 1) * attr.stride.get<T>() -
+ (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
+ attr.weights.shape.get<T>() + attr.adjacent.get<T>();
+}
+
+template <Axis T>
+int32_t CalculateOutput(const BHWDC &input, const ConvolutionTransposed3DAttributes &attr)
+{
+ return (input.get<T>() - 1) * attr.stride.get<T>() -
+ (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
+ attr.weights.shape.get<T>();
+}
+
+inline int32_t StridedSize(int32_t size, int32_t stride)
+{
+ return stride == 0 ? -1 : DivideRoundUp(size, stride);
+}
+
+template <Axis AxisT, typename AttrT> int32_t CalculateOutput(const BHWC &input, const AttrT &attr)
+{
+ return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
+ attr.strides.template get<AxisT>());
+}
+
+template <Axis AxisT, typename AttrT> int32_t CalculateOutput(const BHWDC &input, const AttrT &attr)
+{
+ return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
+ attr.strides.template get<AxisT>());
+}
+
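+// Total padding (prepended + appended) needed along one axis so that the strided output covers
+// the whole input ("SAME" padding). Worked example: input = 7, kernel = 3, dilation = 1,
+// stride = 2 gives dilated_kernel = 3 and padding = max(0, 3 - (7 - 1) % 2 - 1) = 2.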
+int32_t CalculateSamePadding(int32_t input, int32_t kernel, int32_t dilation, int32_t stride)
+{
+ const int32_t dilated_kernel = (kernel - 1) * dilation + 1;
+ return std::max(0, dilated_kernel - (input - 1) % stride - 1);
+}
+
+// Returns a padding that should be present to make sure image size stays
+// the same.
+template <Axis AxisT>
+int32_t CalculateSamePadding(const BHWC &input, const Convolution2DAttributes &attr)
+{
+ return CalculateSamePadding(input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
+ attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
+}
+
+// Returns a padding that should be present to make sure image size stays
+// the same.
+template <Axis AxisT>
+int32_t CalculateSamePadding(const BHWDC &input, const Convolution3DAttributes &attr)
+{
+ return CalculateSamePadding(input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
+ attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
+}
+
+template <Axis AxisT>
+int32_t CalculateSamePadding(const BHWC &input, const ConvolutionTransposedAttributes &attr)
+{
+ return CalculateSamePadding(input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
+ /*dilation=*/1, attr.stride.get<AxisT>());
+}
+
+template <Axis AxisT>
+int32_t CalculateSamePadding(const BHWDC &input, const ConvolutionTransposed3DAttributes &attr)
+{
+ return CalculateSamePadding(input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
+ /*dilation=*/1, attr.stride.get<AxisT>());
+}
+
+template <Axis AxisT>
+int32_t CalculateSamePadding(const BHWC &input, const Pooling2DAttributes &attr)
+{
+ return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
+ /*dilation=*/1, attr.strides.get<AxisT>());
+}
+
+template <Axis AxisT>
+int32_t CalculateSamePadding(const BHWDC &input, const Pooling3DAttributes &attr)
+{
+ return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
+ /*dilation=*/1, attr.strides.get<AxisT>());
+}
+
+template <Axis AxisT>
+int32_t CalculateSamePadding(const BHWC &input, const MaxUnpooling2DAttributes &attr)
+{
+ return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
+ /*dilation=*/1, attr.strides.get<AxisT>());
+}
+
+template <Axis AxisT>
+int32_t CalculateSamePadding(const BHWDC &input, const MaxUnpooling3DAttributes &attr)
+{
+ return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
+ /*dilation=*/1, attr.strides.get<AxisT>());
+}
+
+Padding2D MakeSamePadding(const BHWC &input, const ConvolutionTransposedAttributes &attr)
+{
+ int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
+ int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
+ Padding2D padding;
+ padding.prepended = HW(padding_height / 2, padding_width / 2);
+ padding.appended = HW(padding_height - padding_height / 2, padding_width - padding_width / 2);
+ return padding;
+}
+
+Padding3D MakeSamePadding(const BHWDC &input, const ConvolutionTransposed3DAttributes &attr)
+{
+ int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
+ int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
+ int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
+ Padding3D padding;
+ padding.prepended = HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
+ padding.appended = HWD(padding_height - padding_height / 2, padding_width - padding_width / 2,
+ padding_depth - padding_depth / 2);
+ return padding;
+}
+
+// If padding depends on input, convert it into fixed padding.
+template <class AttrT> Padding2D MakeSamePadding(const BHWC &input, const AttrT &attr)
+{
+ int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
+ int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
+ Padding2D padding;
+ padding.prepended = HW(padding_height / 2, padding_width / 2);
+ padding.appended = HW(padding_height - padding_height / 2, padding_width - padding_width / 2);
+ return padding;
+}
+
+// If padding depends on input, convert it into fixed padding.
+template <class AttrT> Padding3D MakeSamePadding(const BHWDC &input, const AttrT &attr)
+{
+ int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
+ int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
+ int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
+ Padding3D padding;
+ padding.prepended = HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
+ padding.appended = HWD(padding_height - padding_height / 2, padding_width - padding_width / 2,
+ padding_depth - padding_depth / 2);
+ return padding;
+}
+
+} // namespace
+
+BHWC CalculateOutputShape(const BHWC &input, const MaxUnpooling2DAttributes &attr)
+{
+ return BHWC(
+ input.b, input.h * attr.strides.h - attr.padding.prepended.h - attr.padding.appended.h,
+ input.w * attr.strides.w - attr.padding.prepended.w - attr.padding.appended.w, input.c);
+}
+
+BHWDC CalculateOutputShape(const BHWDC &input, const MaxUnpooling3DAttributes &attr)
+{
+ return BHWDC(
+ input.b, input.h * attr.strides.h - attr.padding.prepended.h - attr.padding.appended.h,
+ input.w * attr.strides.w - attr.padding.prepended.w - attr.padding.appended.w,
+ input.d * attr.strides.d - attr.padding.prepended.d - attr.padding.appended.d, input.c);
+}
+
+BHWC CalculateOutputShape(const BHWC &input, const Pooling2DAttributes &attr)
+{
+ return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
+ CalculateOutput<Axis::WIDTH>(input, attr), input.c);
+}
+
+BHWDC CalculateOutputShape(const BHWDC &input, const Pooling3DAttributes &attr)
+{
+ return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
+ CalculateOutput<Axis::WIDTH>(input, attr), CalculateOutput<Axis::DEPTH>(input, attr),
+ input.c);
+}
+
+BHWC CalculateOutputShape(const BHWC &input, const Convolution2DAttributes &attr)
+{
+ return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
+ CalculateOutput<Axis::WIDTH>(input, attr),
+ attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
+}
+
+BHWDC CalculateOutputShape(const BHWDC &input, const Convolution3DAttributes &attr)
+{
+ return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
+ CalculateOutput<Axis::WIDTH>(input, attr), CalculateOutput<Axis::DEPTH>(input, attr),
+ attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
+}
+
+BHWC CalculateOutputShape(const BHWC &input, const ConvolutionTransposedAttributes &attr)
+{
+ return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
+ CalculateOutput<Axis::WIDTH>(input, attr),
+ attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
+}
+
+BHWDC CalculateOutputShape(const BHWDC &input, const ConvolutionTransposed3DAttributes &attr)
+{
+ return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
+ CalculateOutput<Axis::WIDTH>(input, attr), CalculateOutput<Axis::DEPTH>(input, attr),
+ attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
+}
+
+BHWC CalculateOutputShape(const BHWC &input, const DepthwiseConvolution2DAttributes &attr)
+{
+ return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
+ CalculateOutput<Axis::WIDTH>(input, attr),
+ attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
+ attr.weights.shape.get<Axis::INPUT_CHANNELS>());
+}
+
+BHWDC CalculateOutputShape(const BHWDC &input, const DepthwiseConvolution3DAttributes &attr)
+{
+ return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
+ CalculateOutput<Axis::WIDTH>(input, attr), CalculateOutput<Axis::DEPTH>(input, attr),
+ attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
+ attr.weights.shape.get<Axis::INPUT_CHANNELS>());
+}
+
+BHWC CalculateOutputShape(const BHWC &input, const SliceAttributes &attr)
+{
+ (void)input;
+ return BHWC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
+ StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
+ StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
+ StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
+}
+
+BHWDC CalculateOutputShape(const BHWDC &input, const Slice3DAttributes &attr)
+{
+ (void)input;
+ return BHWDC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
+ StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
+ StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
+ StridedSize(attr.ends.d - attr.starts.d, attr.strides.d),
+ StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
+}
+
+BHWC CalculateOutputShape(const BHWC &input, const PadAttributes &attr)
+{
+ return BHWC(
+ attr.appended.b + attr.prepended.b + input.b, attr.appended.h + attr.prepended.h + input.h,
+ attr.appended.w + attr.prepended.w + input.w, attr.appended.c + attr.prepended.c + input.c);
+}
+
+BHWDC CalculateOutputShape(const BHWDC &input, const Pad3DAttributes &attr)
+{
+ return BHWDC(
+ attr.appended.b + attr.prepended.b + input.b, attr.appended.h + attr.prepended.h + input.h,
+ attr.appended.w + attr.prepended.w + input.w, attr.appended.d + attr.prepended.d + input.d,
+ attr.appended.c + attr.prepended.c + input.c);
+}
+
+BHWC CalculateOutputShape(const BHWC &input, const FullyConnectedAttributes &attr)
+{
+ return BHWC(input.b, 1, 1, attr.weights.shape.o);
+}
+
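+// Every axis listed in attr.dims is reduced to 1; all other axes keep the input extent.
+// For example, dims = {Axis::HEIGHT, Axis::WIDTH} on a (1, 224, 224, 3) input yields (1, 1, 1, 3).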
+BHWC CalculateOutputShape(const BHWC &input, const MeanAttributes &attr)
+{
+ const int b = attr.dims.find(Axis::BATCH) == attr.dims.end() ? input.b : 1;
+ const int h = attr.dims.find(Axis::HEIGHT) == attr.dims.end() ? input.h : 1;
+ const int w = attr.dims.find(Axis::WIDTH) == attr.dims.end() ? input.w : 1;
+ const int c = attr.dims.find(Axis::CHANNELS) == attr.dims.end() ? input.c : 1;
+ return BHWC(b, h, w, c);
+}
+
+absl::Status CalculateOutputShape(const std::vector<BHWC> &input, const ConcatAttributes &attr,
+ BHWC *output_shape)
+{
+ BHWC new_shape = input[0];
+ switch (attr.axis)
+ {
+ case Axis::CHANNELS:
+ for (size_t i = 1; i < input.size(); i++)
+ {
+ if (input[i].h != new_shape.h || input[i].w != new_shape.w || input[i].b != new_shape.b)
+ {
+ return absl::InvalidArgumentError(
+ "Height, Width and Batch must be the same when concatenating "
+ "by channels axis");
+ }
+ new_shape.c += input[i].c;
+ }
+ break;
+ case Axis::HEIGHT:
+ for (size_t i = 1; i < input.size(); i++)
+ {
+ if (input[i].w != new_shape.w || input[i].c != new_shape.c || input[i].b != new_shape.b)
+ {
+ return absl::InvalidArgumentError(
+ "Channels, Width and Batch must be the same when concatenating "
+ "by height axis");
+ }
+ new_shape.h += input[i].h;
+ }
+ break;
+ case Axis::WIDTH:
+ for (size_t i = 1; i < input.size(); i++)
+ {
+ if (input[i].h != new_shape.h || input[i].c != new_shape.c || input[i].b != new_shape.b)
+ {
+ return absl::InvalidArgumentError(
+ "Height, Channels and Batch must be the same when concatenating "
+ "by width axis");
+ }
+ new_shape.w += input[i].w;
+ }
+ break;
+ case Axis::BATCH:
+ for (size_t i = 1; i < input.size(); i++)
+ {
+ if (input[i].h != new_shape.h || input[i].c != new_shape.c || input[i].w != new_shape.w)
+ {
+ return absl::InvalidArgumentError(
+ "Width, Height and Channels must be the same when concatenating "
+ "by batch axis");
+ }
+ new_shape.b += input[i].b;
+ }
+ break;
+ default:
+ return absl::InvalidArgumentError("Invalid axis");
+ }
+ *output_shape = new_shape;
+ return absl::OkStatus();
+}
+
+absl::Status CalculateOutputShape(const std::vector<BHWDC> &input, const ConcatAttributes &attr,
+ BHWDC *output_shape)
+{
+ BHWDC new_shape = input[0];
+ switch (attr.axis)
+ {
+ case Axis::CHANNELS:
+ for (size_t i = 1; i < input.size(); ++i)
+ {
+ if (input[i].h != new_shape.h || input[i].w != new_shape.w || input[i].d != new_shape.d ||
+ input[i].b != new_shape.b)
+ {
+ return absl::InvalidArgumentError("Height, Width, Batch and Depth must be the same when "
+ "concatenating "
+ "by channels axis");
+ }
+ new_shape.c += input[i].c;
+ }
+ break;
+ case Axis::HEIGHT:
+ for (size_t i = 1; i < input.size(); ++i)
+ {
+ if (input[i].w != new_shape.w || input[i].c != new_shape.c || input[i].d != new_shape.d ||
+ input[i].b != new_shape.b)
+ {
+ return absl::InvalidArgumentError(
+ "Width, Depth, Batch and Channels must be the same when "
+ "concatenating "
+ "by height axis");
+ }
+ new_shape.h += input[i].h;
+ }
+ break;
+ case Axis::WIDTH:
+ for (size_t i = 1; i < input.size(); ++i)
+ {
+ if (input[i].h != new_shape.h || input[i].c != new_shape.c || input[i].d != new_shape.d ||
+ input[i].b != new_shape.b)
+ {
+ return absl::InvalidArgumentError(
+ "Height, Depth, Batch and Channels must be the same when "
+ "concatenating "
+ "by width axis");
+ }
+ new_shape.w += input[i].w;
+ }
+ break;
+ case Axis::DEPTH:
+ for (size_t i = 1; i < input.size(); ++i)
+ {
+ if (input[i].w != new_shape.w || input[i].h != new_shape.h || input[i].c != new_shape.c ||
+ input[i].b != new_shape.b)
+ {
+ return absl::InvalidArgumentError(
+ "Width, Height, Batch and Channels must be the same when "
+ "concatenating "
+ "by depth axis");
+ }
+ new_shape.d += input[i].d;
+ }
+ break;
+ case Axis::BATCH:
+ for (size_t i = 1; i < input.size(); ++i)
+ {
+ if (input[i].w != new_shape.w || input[i].h != new_shape.h || input[i].c != new_shape.c ||
+ input[i].d != new_shape.d)
+ {
+ return absl::InvalidArgumentError(
+ "Width, Height, Depth and Channels must be the same when "
+ "concatenating "
+ "by batch axis");
+ }
+ new_shape.b += input[i].b;
+ }
+ break;
+ default:
+ return absl::InvalidArgumentError("Invalid axis");
+ }
+ *output_shape = new_shape;
+ return absl::OkStatus();
+}
+
+Padding2D CalculateSamePadding(const BHWC &input, const Convolution2DAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
+Padding3D CalculateSamePadding(const BHWDC &input, const Convolution3DAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
+Padding2D CalculateSamePadding(const BHWC &input, const ConvolutionTransposedAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
+Padding3D CalculateSamePadding(const BHWDC &input, const ConvolutionTransposed3DAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
+Padding2D CalculateSamePadding(const BHWC &input, const DepthwiseConvolution2DAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
+Padding3D CalculateSamePadding(const BHWDC &input, const DepthwiseConvolution3DAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
+Padding2D CalculateSamePadding(const BHWC &input, const Pooling2DAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
+Padding3D CalculateSamePadding(const BHWDC &input, const Pooling3DAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
+Padding2D CalculateSamePadding(const BHWC &input, const MaxUnpooling2DAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
+Padding3D CalculateSamePadding(const BHWDC &input, const MaxUnpooling3DAttributes &attr)
+{
+ return MakeSamePadding(input, attr);
+}
+
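+// Scale that maps output coordinates back into input coordinates. For example, with
+// input_size = 4 and output_size = 8 the scale is (4 - 1) / (8 - 1) = 3/7 when align_corners
+// is set (and both sizes are > 1), and 4 / 8 = 0.5 otherwise.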
+float CalculateResizeScale(int32_t input_size, int32_t output_size, const Resize2DAttributes &attr)
+{
+ return attr.align_corners && input_size > 1 && output_size > 1
+ ? static_cast<float>(input_size - 1) / (output_size - 1)
+ : static_cast<float>(input_size) / output_size;
+}
+
+float CalculateResizeScale(int32_t input_size, int32_t output_size, const Resize3DAttributes &attr)
+{
+ return attr.align_corners && input_size > 1 && output_size > 1
+ ? static_cast<float>(input_size - 1) / (output_size - 1)
+ : static_cast<float>(input_size) / output_size;
+}
+
+BHWC CalculateOutputShape(const BHWC &input, const Resize2DAttributes &attr)
+{
+ return BHWC(input.b, attr.new_shape.h, attr.new_shape.w, input.c);
+}
+
+BHWDC CalculateOutputShape(const BHWDC &input, const Resize3DAttributes &attr)
+{
+ return BHWDC(input.b, attr.new_shape.h, attr.new_shape.w, attr.new_shape.d, input.c);
+}
+
+BHWC CalculateOutputShape(const BHWC &input, const TransposeAttributes &attr)
+{
+ return BHWC(input.get(attr.perm.b), input.get(attr.perm.h), input.get(attr.perm.w),
+ input.get(attr.perm.c));
+}
+
+BHWDC CalculateOutputShape(const BHWDC &input, const Transpose3DAttributes &attr)
+{
+ return BHWDC(input.get(attr.perm.b), input.get(attr.perm.h), input.get(attr.perm.w),
+ input.get(attr.perm.d), input.get(attr.perm.c));
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_OPERATIONS_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_OPERATIONS_H__
+
+#include <cstdint>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "absl/types/variant.h"
+
+#include "DataType.h"
+#include "Shape.h"
+#include "Status.h"
+#include "InternalTensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class OperationType
+{
+ UNKNOWN = 0,
+ // ABS,
+ ADD,
+ // BATCH_TO_SPACE,
+ // BATCH_NORMALIZATION,
+ // BATCHED_MATMUL,
+ // CONCAT,
+ // CONST,
+ // CONVOLUTION_2D,
+ // CONVOLUTION_TRANSPOSED,
+ // COPY,
+ // COS,
+ // DEPTHWISE_CONVOLUTION,
+ // DIV,
+ // ELU,
+ // EQUAL,
+ // EXP,
+ // FULLY_CONNECTED,
+ // GREATER,
+ // GREATER_EQUAL,
+ // HARD_SWISH,
+ // LESS,
+ // LESS_EQUAL,
+ // LOG,
+ // LSTM,
+ // MAXIMUM,
+ // MAX_UNPOOLING_2D,
+ // MEAN,
+ // MEAN_STDDEV_NORMALIZATION,
+ // MINIMUM,
+ // MUL,
+ // NEG,
+ // NOT_EQUAL,
+ // PAD,
+ // POOLING_2D,
+ // POW,
+ // PRELU,
+ // Used to accurately run inference on quantized models.
+ // QUANTIZE_AND_DEQUANTIZE,
+ // REDUCE_MAXIMUM,
+ // REDUCE_MINIMUM,
+ // REDUCE_PRODUCT,
+ // REDUCE_SUM,
+ // RELU,
+ // RESHAPE,
+ // RESIZE,
+ // RSQRT,
+ // SIGMOID,
+ // SIN,
+ // SLICE,
+ // SOFTMAX,
+ // SPACE_TO_BATCH,
+ // SPACE_TO_DEPTH,
+ // SQRT,
+ // SQUARE,
+ // SQUARED_DIFF,
+ // SUB,
+ // TANH,
+ // TRANSPOSE,
+};
+
+std::string ToString(enum OperationType op);
+
+OperationType OperationTypeFromString(const std::string &name);
+
+typedef absl::variant<absl::monostate, InternalTensor<HWC, DataType::FLOAT32>,
+ InternalTensor<Linear, DataType::FLOAT32>, float>
+ TensorOrScalar;
+
+struct Padding2D
+{
+ Padding2D() = default;
+ Padding2D(const Padding2D &);
+ Padding2D &operator=(const Padding2D &value);
+ bool operator==(const Padding2D &value);
+ bool operator!=(const Padding2D &value);
+ Padding2D &operator-(const Padding2D &value);
+
+ // Padding values for every axis (if needed), where 'prepended' defines
+ // padding for the beginning of each axis and 'appended' represents the end
+ // part of the corresponding axis.
+ HW prepended = HW(-1, -1);
+ HW appended = HW(-1, -1);
+};
+
+struct Padding3D
+{
+ Padding3D() = default;
+ Padding3D(const Padding3D &);
+ Padding3D &operator=(const Padding3D &value);
+ bool operator==(const Padding3D &value);
+ bool operator!=(const Padding3D &value);
+ Padding3D &operator-(const Padding3D &value);
+ // Padding values for every axis (if needed), where 'prepended' defines
+ // padding for the beginning of each axis and 'appended' represents the end
+ // part of the corresponding axis.
+ HWD prepended = HWD(0, 0, 0);
+ HWD appended = HWD(0, 0, 0);
+};
+
+struct Crop2D : public Padding2D
+{
+};
+
+struct SpaceToBatchAttributes
+{
+ HW block;
+ Padding2D padding;
+};
+
+struct BatchToSpaceAttributes
+{
+ HW block;
+ Crop2D crop;
+};
+
+enum class PoolingType
+{
+ UNDEFINED = 0,
+
+ // average pooling
+ AVERAGE = 1,
+
+ // max pooling
+ MAX = 2,
+};
+
+struct Pooling2DAttributes
+{
+ PoolingType type = PoolingType::UNDEFINED;
+ // Strides for every axis.
+ HW strides = HW(-1, -1);
+ HW kernel = HW(-1, -1);
+ Padding2D padding;
+ // NOTE(akulik): technically the number of outputs from Pooling node indicates
+ // whether indices are needed or not, but I decided to keep it inside
+ // attributes to simplify processing.
+ bool output_indices = false;
+};
+
+struct Pooling3DAttributes
+{
+ PoolingType type = PoolingType::UNDEFINED;
+ // Strides for every axis.
+ HWD strides = HWD(0, 0, 0);
+ HWD kernel = HWD(0, 0, 0);
+ Padding3D padding;
+ // NOTE(akulik): technically the number of outputs from Pooling node indicates
+ // whether indices are needed or not, but I decided to keep it inside
+ // attributes to simplify processing.
+ bool output_indices = false;
+};
+
+struct MaxUnpooling2DAttributes
+{
+ // Strides for every axis.
+ HW strides = HW(-1, -1);
+ HW kernel = HW(-1, -1);
+ Padding2D padding;
+};
+
+struct MaxUnpooling3DAttributes
+{
+ // Strides for every axis.
+ HWD strides = HWD(0, 0, 0);
+ HWD kernel = HWD(0, 0, 0);
+ Padding3D padding;
+};
+
+struct MeanAttributes
+{
+ // The set of dimensions to calculate the mean along.
+ std::set<Axis> dims;
+};
+
+struct ConcatAttributes
+{
+ // Defines the axis along which to concatenate.
+ Axis axis = Axis::UNKNOWN;
+};
+
+// @return shape of a tensor after MaxUnpooling2D operation is applied to
+// the given input.
+BHWC CalculateOutputShape(const BHWC &input, const MaxUnpooling2DAttributes &attr);
+
+// @return shape of a tensor after MaxUnpooling3D operation is applied to
+// the given input.
+BHWDC CalculateOutputShape(const BHWDC &input, const MaxUnpooling3DAttributes &attr);
+
+// @return shape of a tensor after Pooling2D operation is applied to the given
+// input.
+BHWC CalculateOutputShape(const BHWC &input, const Pooling2DAttributes &attr);
+
+// @return shape of a tensor after Pooling3D operation is applied to the given
+// input.
+BHWDC CalculateOutputShape(const BHWDC &input, const Pooling3DAttributes &attr);
+
+// @return shape of a tensor after Concat operation is applied to the given
+// input.
+absl::Status CalculateOutputShape(const std::vector<BHWC> &input, const ConcatAttributes &attr,
+ BHWC *output_shape);
+
+// @return shape of a tensor after Concat operation is applied to the given
+// input.
+absl::Status CalculateOutputShape(const std::vector<BHWDC> &input, const ConcatAttributes &attr,
+ BHWDC *output_shape);
+
+// @return padding for pooling operation to make sure output keeps the same shape
+// as the given input.
+Padding2D CalculateSamePadding(const BHWC &input, const Pooling2DAttributes &attr);
+
+// @return padding for pooling operation to make sure output keeps the same shape
+// as the given input.
+Padding3D CalculateSamePadding(const BHWDC &input, const Pooling3DAttributes &attr);
+
+// @return padding for max unpooling operation to make sure output keeps the same
+// shape as the given input.
+Padding2D CalculateSamePadding(const BHWC &input, const MaxUnpooling2DAttributes &attr);
+
+// @return padding for max unpooling operation to make sure output keeps the same
+// shape as the given input.
+Padding3D CalculateSamePadding(const BHWDC &input, const MaxUnpooling3DAttributes &attr);
+
+struct Convolution2DAttributes
+{
+ HW strides = HW(1, 1); // Along each axis.
+ HW dilations = HW(1, 1); // Along each axis.
+ Padding2D padding;
+
+ InternalTensor<OHWI, DataType::FLOAT32> weights;
+ InternalTensor<Linear, DataType::FLOAT32> bias; // optional
+};
+
+struct Convolution3DAttributes
+{
+ HWD strides = HWD(0, 0, 0); // Along each axis.
+ HWD dilations = HWD(0, 0, 0); // Along each axis.
+ Padding3D padding;
+
+ InternalTensor<OHWDI, DataType::FLOAT32> weights;
+ InternalTensor<Linear, DataType::FLOAT32> bias; // optional
+};
+
+// @return shape of a tensor after Convolution2D operation is applied to
+// the given input.
+BHWC CalculateOutputShape(const BHWC &input, const Convolution2DAttributes &attr);
+
+// @return shape of a tensor after Convolution3D operation is applied to
+// the given input.
+BHWDC CalculateOutputShape(const BHWDC &input, const Convolution3DAttributes &attr);
+
+// @return padding for convolution operation to make sure output keeps the same
+// shape as the given input.
+Padding2D CalculateSamePadding(const BHWC &input, const Convolution2DAttributes &attr);
+
+// @return padding for convolution operation to make sure output keeps the same
+// shape as the given input.
+Padding3D CalculateSamePadding(const BHWDC &input, const Convolution3DAttributes &attr);
+
+struct ConvolutionTransposedAttributes
+{
+ HW stride = HW(1, 1); // Along each axis.
+ HW adjacent; // TODO(sorokin): No op on Flow.
+ Padding2D padding;
+
+ InternalTensor<OHWI, DataType::FLOAT32> weights;
+ InternalTensor<Linear, DataType::FLOAT32> bias; // optional
+};
+
+struct ConvolutionTransposed3DAttributes
+{
+ HWD stride = HWD(0, 0, 0); // Along each axis.
+ Padding3D padding;
+
+ InternalTensor<OHWDI, DataType::FLOAT32> weights;
+ InternalTensor<Linear, DataType::FLOAT32> bias; // optional
+};
+
+Padding2D CalculateSamePadding(const BHWC &input, const ConvolutionTransposedAttributes &attr);
+
+Padding3D CalculateSamePadding(const BHWDC &input, const ConvolutionTransposed3DAttributes &attr);
+
+// @return shape of a tensor after ConvolutionTransposed operation is applied to
+// the given input.
+BHWC CalculateOutputShape(const BHWC &input, const ConvolutionTransposedAttributes &attr);
+
+// @return shape of a tensor after ConvolutionTransposed3D operation is applied
+// to the given input.
+BHWDC CalculateOutputShape(const BHWDC &input, const ConvolutionTransposed3DAttributes &attr);
+
+struct DepthwiseConvolution2DAttributes : public Convolution2DAttributes
+{
+};
+struct DepthwiseConvolution3DAttributes : public Convolution3DAttributes
+{
+};
+
+// @return shape of a tensor after DepthwiseConvolution2D operation is applied
+// to the given input.
+BHWC CalculateOutputShape(const BHWC &input, const DepthwiseConvolution2DAttributes &attr);
+
+// @return shape of a tensor after DepthwiseConvolution3D operation is applied
+// to the given input.
+BHWDC CalculateOutputShape(const BHWDC &input, const DepthwiseConvolution3DAttributes &attr);
+
+// @return padding for depthwise convolution operation to make sure output keeps
+// the same shape as the given input.
+Padding2D CalculateSamePadding(const BHWC &input, const DepthwiseConvolution2DAttributes &attr);
+
+// @return padding for depthwise convolution operation to make sure output keeps
+// the same shape as the given input.
+Padding3D CalculateSamePadding(const BHWDC &input, const DepthwiseConvolution3DAttributes &attr);
+
+// f(x):= {
+// if x < 0 : x -> alpha * x
+// if x >= 0 : x -> min(clip, x)
+// }
+//
+// Examples:
+// - ReLU: clip = 0, alpha = 0
+// - ReLU6: clip = 6, alpha = 0
+// - Leaky ReLU: clip = 0, alpha = a
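+//
+// A hedged worked example of the piecewise definition above (input values are
+// illustrative, not taken from this patch):
+//   ReLU6      (clip = 6, alpha = 0):               f(-2) = 0,    f(3) = 3, f(9) = 6
+//   Leaky ReLU (clip = 0, i.e. unset; alpha = 0.1): f(-2) = -0.2, f(3) = 3, f(9) = 9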
+struct ReLUAttributes
+{
+  // clip <= 0 means it is not set.
+ float clip = 0;
+
+ float alpha = 0;
+};
+
+struct PReLUAttributes
+{
+  // clip <= 0 means it is not set.
+ float clip = 0;
+
+ // If alpha is linear, then it is sharded across CHANNELS axis, otherwise
+ // full shape alpha is required.
+ absl::variant<InternalTensor<Linear, DataType::FLOAT32>, InternalTensor<HWC, DataType::FLOAT32>>
+ alpha;
+};
+
+struct ReduceAttributes
+{
+ Axis axis = Axis::UNKNOWN;
+};
+
+struct SoftmaxAttributes
+{
+ Axis axis = Axis::UNKNOWN;
+};
+
+enum LstmKernelType
+{
+ FULL = 0,
+ BASIC = 1, // Currently, only basic is supported.
+};
+
+struct LstmAttributes
+{
+ LstmKernelType kernel_type = LstmKernelType::BASIC;
+};
+
+enum class SamplingType
+{
+ UNKNOWN = 0,
+ NEAREST = 1,
+ BILINEAR = 2,
+};
+
+struct Resize2DAttributes
+{
+ HW new_shape;
+
+ SamplingType type = SamplingType::UNKNOWN;
+
+ // If true, the centers of the 4 corner pixels of the input and output tensors
+ // are aligned, preserving the values at the corner pixels. Defaults to false.
+ bool align_corners = false;
+
+ bool half_pixel_centers = false;
+};
+
+// TODO(b/147771327): rename to Resize3D
+struct Resize3DAttributes
+{
+ HWD new_shape;
+
+ SamplingType type = SamplingType::NEAREST;
+
+ // If true, the centers of the 8 corner pixels of the input and output tensors
+ // are aligned, preserving the values at the corner pixels. Defaults to false.
+ bool align_corners = false;
+
+ bool half_pixel_centers = false;
+};
+
+float CalculateResizeScale(int32_t input_size, int32_t output_size, const Resize2DAttributes &attr);
+
+float CalculateResizeScale(int32_t input_size, int32_t output_size, const Resize3DAttributes &attr);
+
+// @return shape of a tensor after scale operation is applied to the given
+// input.
+BHWC CalculateOutputShape(const BHWC &input, const Resize2DAttributes &attr);
+
+// @return shape of a tensor after scale operation is applied to the given
+// input.
+BHWDC CalculateOutputShape(const BHWDC &input, const Resize3DAttributes &attr);
+
+enum class PaddingContentType
+{
+ ZEROS = 0,
+ REFLECT = 1,
+ EDGE = 2,
+};
+
+struct PadAttributes
+{
+ PaddingContentType type = PaddingContentType::ZEROS;
+
+ BHWC prepended;
+ BHWC appended;
+};
+
+// @return shape of a tensor after Pad operation is applied to the given input.
+BHWC CalculateOutputShape(const BHWC &input, const PadAttributes &attr);
+
+struct Pad3DAttributes
+{
+ PaddingContentType type = PaddingContentType::ZEROS;
+
+ BHWDC prepended;
+ BHWDC appended;
+};
+
+// @return shape of a tensor after Pad3D operation is applied to the given
+// input.
+BHWDC CalculateOutputShape(const BHWDC &input, const Pad3DAttributes &attr);
+
+struct ConstTensorAttributes
+{
+ InternalTensor<BHWC, DataType::FLOAT32> tensor;
+};
+
+// Simple slicing without advanced support for shrinking, reverse slicing etc.
+struct SliceAttributes
+{
+ // Specifies start and end dimensions for slicing.
+ BHWC starts;
+ BHWC ends;
+
+ // Stride should be >= 1.
+ BHWC strides;
+};
+
+// @return shape of a tensor after Slice2D operation is applied to the given
+// input.
+BHWC CalculateOutputShape(const BHWC &input, const SliceAttributes &attr);
+
+// Simple slicing without advanced support for shrinking, reverse slicing etc.
+struct Slice3DAttributes
+{
+ // Specifies start and end dimensions for slicing.
+ BHWDC starts;
+ BHWDC ends;
+
+ // Stride should be >= 1.
+ BHWDC strides;
+};
+
+// @return shape of a tensor after Slice3D operation is applied to the given
+// input.
+BHWDC CalculateOutputShape(const BHWDC &input, const Slice3DAttributes &attr);
+
+struct FullyConnectedAttributes
+{
+ InternalTensor<OHWI, DataType::FLOAT32> weights;
+ InternalTensor<Linear, DataType::FLOAT32> bias;
+};
+
+// @return shape of a tensor after FullyConnected operation is applied to
+// the given input.
+BHWC CalculateOutputShape(const BHWC &input, const FullyConnectedAttributes &attr);
+
+// @return shape of a tensor after Mean operation is applied to the given input.
+BHWC CalculateOutputShape(const BHWC &input, const MeanAttributes &attr);
+
+struct ElementwiseAttributes
+{
+ TensorOrScalar param;
+  // For an elementwise operation with 2 inputs op(A, B), runtime_tensor_is_second
+  // is true when the runtime tensor is B (in the second position). This matters
+  // for non-commutative ops, for example subtraction.
+ bool runtime_tensor_is_second = false;
+};
+
+struct ReshapeAttributes
+{
+ BHWC new_shape;
+};
+
+struct Reshape3DAttributes
+{
+ BHWDC new_shape;
+};
+
+struct TransposeAttributes
+{
+  // A permutation of the dimensions of the input tensor.
+ BHWC perm;
+};
+
+// @return shape of a tensor after Transpose operation is applied to
+// the given input.
+BHWC CalculateOutputShape(const BHWC &input, const TransposeAttributes &attr);
+
+struct Transpose3DAttributes
+{
+  // A permutation of the dimensions of the input tensor.
+ BHWDC perm;
+};
+
+// @return shape of a tensor after Transpose3D operation is applied to
+// the given input.
+BHWDC CalculateOutputShape(const BHWDC &input, const Transpose3DAttributes &attr);
+
+struct SpaceToDepthAttributes
+{
+ int block_size;
+};
+
+// These help perform a combination of Quantize & Dequantize to adjust float
+// values like quantized inference would.
+struct QuantizeAndDequantizeAttributes
+{
+ float min = 0;
+ float max = 0;
+ float scale = 0;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_OPERATIONS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Precision.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+std::string ToString(CalculationsPrecision precision)
+{
+ switch (precision)
+ {
+ case CalculationsPrecision::F32_F16:
+ return "CalculationsPrecision::F32_F16";
+ case CalculationsPrecision::F32:
+ return "CalculationsPrecision::F32";
+ case CalculationsPrecision::F16:
+ return "CalculationsPrecision::F16";
+ }
+ return " ";
+}
+
+DataType DeduceDataTypeFromPrecision(CalculationsPrecision precision)
+{
+ if (precision == CalculationsPrecision::F32)
+ {
+ return DataType::FLOAT32;
+ }
+ else
+ {
+ return DataType::FLOAT16;
+ }
+ return DataType::UNKNOWN;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_PRECISION_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_PRECISION_H__
+
+#include <string>
+
+#include "DataType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class CalculationsPrecision
+{
+ F32,
+ F32_F16,
+ F16
+};
+// F32 - all data and all math ops in F32
+// F16 - all data and all math ops in F16
+// F32_F16 - as F16, but some operations (Convolution,
+// DepthwiseConvolution, FullyConnected, ConvolutionTransposed)
+// keep their accumulator in F32: typically 4 MADs are computed in F16 and
+// summed, then the partial sum is converted to F32 and added to the accumulator.
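+//
+// For reference, DeduceDataTypeFromPrecision (see Precision.cc in this patch)
+// maps these values as follows:
+//   F32     -> DataType::FLOAT32
+//   F32_F16 -> DataType::FLOAT16
+//   F16     -> DataType::FLOAT16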
+
+DataType DeduceDataTypeFromPrecision(CalculationsPrecision precision);
+
+std::string ToString(CalculationsPrecision precision);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_PRECISION_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ProgramCache.h"
+
+#include <cstdint>
+#include <string>
+
+#include "ClProgram.h"
+#include "Status.h"
+#include "Util.h"
+#include "farmhash.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+ProgramCache::ProgramDescriptor::ProgramDescriptor(const std::string &code_text,
+ const std::string &options,
+ bool use_fingerprints)
+ : code(code_text), compiler_options(options), use_fingerprint(use_fingerprints)
+{
+ const uint64_t code_fingerprint = ::util::Fingerprint64(code);
+ const uint64_t options_fingerprint = ::util::Fingerprint64(compiler_options);
+ fingerprint = code_fingerprint + options_fingerprint;
+}
+
+ProgramCache::ProgramDescriptor::ProgramDescriptor(uint64_t fingerprints)
+ : fingerprint(fingerprints), use_fingerprint(true)
+{
+}
+
+ProgramCache::ProgramCache(ProgramCache &&program_cache)
+ : use_fingerprints_(program_cache.use_fingerprints_),
+ programs_(std::move(program_cache.programs_))
+{
+}
+
+ProgramCache &ProgramCache::operator=(ProgramCache &&program_cache)
+{
+ if (this != &program_cache)
+ {
+ use_fingerprints_ = program_cache.use_fingerprints_;
+ programs_ = std::move(program_cache.programs_);
+ }
+ return *this;
+}
+
+absl::Status ProgramCache::GetOrCreateCLKernel(const std::string &code,
+ const std::string &function_name,
+ const std::vector<CompilerOptions> &compiler_options,
+ const CLContext &context, const CLDevice &device,
+ CLKernel *result)
+{
+ const std::string options = CompilerOptionsToString(device, compiler_options);
+ ProgramDescriptor desc{code, options, use_fingerprints_};
+ auto it = programs_.find(desc);
+ if (it != programs_.end())
+ {
+ return result->CreateFromProgram(it->second, function_name);
+ }
+
+ CLProgram program;
+ RETURN_IF_ERROR(CreateCLProgram(code, options, context, device, &program));
+ RETURN_IF_ERROR(result->CreateFromProgram(program, function_name));
+ programs_.insert(std::make_pair(std::move(desc), std::move(program)));
+ return absl::OkStatus();
+}
+
+absl::Status ProgramCache::GetOrCreateCLKernel(const std::string &code,
+ const std::string &function_name,
+ const CLContext &context, const CLDevice &device,
+ CLKernel *result)
+{
+ return GetOrCreateCLKernel(code, function_name, {}, context, device, result);
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_PROGRAM_CACHE_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_PROGRAM_CACHE_H__
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/types/span.h"
+#include "ClContext.h"
+#include "ClDevice.h"
+#include "ClKernel.h"
+#include "ClProgram.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class ProgramCache
+{
+public:
+ ProgramCache() = default;
+
+ ProgramCache(ProgramCache &&program_cache);
+ ProgramCache &operator=(ProgramCache &&program_cache);
+ ProgramCache(const ProgramCache &) = delete;
+ ProgramCache &operator=(const ProgramCache &) = delete;
+
+ absl::Status GetOrCreateCLKernel(const std::string &code, const std::string &function_name,
+ const std::vector<CompilerOptions> &compiler_options,
+ const CLContext &context, const CLDevice &device,
+ CLKernel *result);
+
+ absl::Status GetOrCreateCLKernel(const std::string &code, const std::string &function_name,
+ const CLContext &context, const CLDevice &device,
+ CLKernel *result);
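+
+  // Hedged usage sketch (object and kernel names are illustrative, not part of
+  // this patch):
+  //   CLKernel kernel;
+  //   RETURN_IF_ERROR(cache.GetOrCreateCLKernel(code, "main_function",
+  //                                             context, device, &kernel));
+  //   // A second call with identical code and compiler options reuses the
+  //   // CLProgram cached in programs_.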
+
+private:
+ struct ProgramDescriptor
+ {
+ ProgramDescriptor() = default;
+ ProgramDescriptor(const std::string &code_text, const std::string &options,
+ bool use_fingerprint);
+ explicit ProgramDescriptor(uint64_t fingerprint);
+
+ std::string code;
+ std::string compiler_options;
+ uint64_t fingerprint;
+ bool use_fingerprint;
+ };
+ struct ProgramDescriptorHasher
+ {
+ std::size_t operator()(const ProgramDescriptor &k) const
+ {
+ if (k.use_fingerprint)
+ {
+ return std::hash<uint64_t>()(k.fingerprint);
+ }
+ else
+ {
+ return std::hash<std::string>()(k.code) + std::hash<std::string>()(k.compiler_options);
+ }
+ }
+ };
+ struct ProgramDescriptorEqual
+ {
+ bool operator()(const ProgramDescriptor &a, const ProgramDescriptor &b) const
+ {
+ if (a.use_fingerprint && b.use_fingerprint)
+ {
+ return a.fingerprint == b.fingerprint;
+ }
+ else
+ {
+ return a.compiler_options == b.compiler_options && a.code == b.code;
+ }
+ }
+ };
+
+ // There is a low probability of a hash collision when cache is deserialized
+ // because only fingerprints are serialized instead of full source code.
+ bool use_fingerprints_ = false;
+ absl::flat_hash_map<ProgramDescriptor, CLProgram, ProgramDescriptorHasher, ProgramDescriptorEqual>
+ programs_;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_PROGRAM_CACHE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Shape.h"
+
+#include <stdint.h>
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+struct GetAxisByIndexFunc
+{
+ template <Layout T> Axis operator()() const { return GetAxis<T>(index); }
+ int32_t index;
+};
+
+struct GetIndexByAxisFunc
+{
+ template <Layout T> int operator()() const { return GetAxisIndex<T>(axis); }
+ Axis axis;
+};
+
+struct NumAxisFunc
+{
+ template <Layout T> int operator()() const { return Size<T>(); }
+};
+
+} // namespace
+
+std::string ToString(Axis axis)
+{
+ switch (axis)
+ {
+ case Axis::BATCH:
+ return "batch";
+ case Axis::CHANNELS:
+ return "channels";
+ case Axis::INPUT_CHANNELS:
+ return "input_channels";
+ case Axis::OUTPUT_CHANNELS:
+ return "output_channels";
+ case Axis::HEIGHT:
+ return "height";
+ case Axis::WIDTH:
+ return "width";
+ case Axis::VALUE:
+ return "value";
+ case Axis::DEPTH:
+ return "depth";
+ case Axis::UNKNOWN:
+ return "unknown";
+ }
+ return "undefined";
+}
+
+std::string ToString(Layout layout)
+{
+ switch (layout)
+ {
+ case Layout::SCALAR:
+ return "scalar";
+ case Layout::LINEAR:
+ return "linear";
+ case Layout::HW:
+ return "hw";
+ case Layout::HWD:
+ return "hwd";
+ case Layout::CHW:
+ return "chw";
+ case Layout::HWC:
+ return "hwc";
+ case Layout::HWDC:
+ return "hwdc";
+ case Layout::OHWI:
+ return "ohwi";
+ case Layout::IHWO:
+ return "ihwo";
+ case Layout::OIHW:
+ return "oihw";
+ case Layout::IOHW:
+ return "iohw";
+ case Layout::BHWC:
+ return "bhwc";
+ case Layout::BHWDC:
+ return "bhwdc";
+ case Layout::OHWDI:
+      return "ohwdi";
+ case Layout::UNKNOWN:
+ return "unknown";
+ }
+ return "undefined";
+}
+
+Axis GetAxis(Layout layout, int32_t index)
+{
+ return DispatchByLayout(layout, GetAxisByIndexFunc{index});
+}
+
+int GetAxisIndex(Layout layout, Axis axis)
+{
+ return DispatchByLayout(layout, GetIndexByAxisFunc{axis});
+}
+
+bool HasAxis(Layout layout, Axis axis) { return GetAxisIndex(layout, axis) >= 0; }
+
+int Size(Layout layout) { return DispatchByLayout(layout, NumAxisFunc()); }
+
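+// e.g. a Shape with Layout::BHWC and dimensions {1, 224, 224, 3} renders as
+// "{bhwc, {1, 224, 224, 3}}" (illustrative values).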
+std::string ToString(const Shape &s)
+{
+ return absl::StrCat("{", ToString(s.layout), ", {", absl::StrJoin(s.dimensions, ", "), "}}");
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_SHAPE_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_SHAPE_H__
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <array>
+#include <functional>
+#include <numeric>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class Axis
+{
+ UNKNOWN = 0,
+ CHANNELS = 1,
+ INPUT_CHANNELS = 2,
+ OUTPUT_CHANNELS = 3,
+ HEIGHT = 4,
+ WIDTH = 5,
+ BATCH = 6,
+ VALUE = 7,
+ DEPTH = 8,
+};
+
+std::string ToString(Axis t);
+
+// Layout represents axis order.
+enum class Layout
+{
+ UNKNOWN = 0,
+ SCALAR = 1,
+ LINEAR = 2,
+ HW = 3,
+ CHW = 4,
+ HWC = 5,
+ OIHW = 6,
+ OHWI = 7,
+ IHWO = 8,
+ IOHW = 9,
+ BHWC = 10,
+ HWDC = 11,
+ BHWDC = 12,
+ HWD = 13,
+ OHWDI = 14,
+};
+
+std::string ToString(Layout l);
+
+// Returns the number of axes for the fixed layout.
+template <Layout T> constexpr int Size();
+
+// Returns number of axis for the given layout.
+int Size(Layout layout);
+
+// Returns Axis for the given index and fixed layout.
+template <Layout T> constexpr Axis GetAxis(int index);
+
+// Returns axis for the given layout and index.
+Axis GetAxis(Layout layout, int32_t index);
+
+// Returns axis index for the given axis and fixed layout.
+template <Layout T> constexpr int GetAxisIndex(Axis axis);
+
+// Returns axis index for the given layout and axis.
+int GetAxisIndex(Layout layout, Axis axis);
+
+// Checks if fixed layout has given axis
+template <Layout T> constexpr bool HasAxis(Axis axis);
+
+// Checks if given layout has given axis
+bool HasAxis(Layout layout, Axis axis);
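+
+// Illustrative values for the helpers above, per the Layout definitions in this
+// header (a hedged reference, not exhaustive):
+//   Size(Layout::BHWC)                         == 4
+//   GetAxis(Layout::BHWC, 2)                   == Axis::WIDTH
+//   GetAxisIndex(Layout::BHWC, Axis::CHANNELS) == 3
+//   HasAxis(Layout::HW, Axis::CHANNELS)        == false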
+
+// Stores Layout(axis set and order) and value for dimensions.
+struct Shape
+{
+ Shape() : layout(Layout::UNKNOWN), dimensions() {}
+
+ explicit Shape(Layout t) : layout(t), dimensions(Size(t)) {}
+
+ Shape(Layout t, std::vector<int32_t> d) : layout(t), dimensions(std::move(d)) {}
+
+ bool operator==(const Shape &other) const
+ {
+ return (layout == other.layout) && (dimensions == other.dimensions);
+ }
+
+ bool operator!=(const Shape &other) const { return !operator==(other); }
+
+  // All methods below match the same methods defined in StrongShape to
+  // make sure generic algorithms work both ways.
+
+  // Returns a dimension, or -1 if it is not found.
+ template <Axis D> int32_t get() const;
+ int32_t get(Axis axis) const;
+
+ template <Axis D> bool set(int32_t t);
+ bool set(Axis axis, int32_t t);
+
+ Axis axis(int index) const { return GetAxis(layout, index); }
+
+ int index(Axis axis) const { return GetAxisIndex(layout, axis); }
+
+ bool has(Axis axis) const { return HasAxis(layout, axis); }
+
+ int64_t DimensionsProduct() const
+ {
+ return std::accumulate(dimensions.begin(), dimensions.end(), 1ll, std::multiplies<int64_t>());
+ }
+
+ Layout layout = Layout::UNKNOWN;
+
+ std::vector<int32_t> dimensions;
+};
+
+std::string ToString(const Shape &s);
+
+// StrongShape provides convenient explicit access to dimensions stored in
+// shape, e.g. StrongShape<Layout::HW> s; provides s.h and s.w accessors.
+//
+// There is a conversion possible both ways between Shape and StrongShape.
+//
+// OIHW oihw; // specific shape
+// Shape l = oihw.ToShape();
+//
+// OHWI other; // notice not the same but compatible shape.
+// if (!other.Adopt(l)) {
+// // error handling
+// }
+//
+// StrongShape supports the following set of operations:
+//
+//   // Returns the number of axes in the shape class.
+// static constexpr int size();
+//
+// // Returns Axis for the given index or Axis::UNKNOWN if index
+// // falls outside of the defined range in this shape.
+// static constexpr Axis axis(int index);
+//
+// // Returns index for the given axis or -1 if axis is not defined in this
+// // shape.
+// static constexpr int index(Axis axis);
+//
+// // Getters
+// int32_t get(int index) const;
+// int32_t get(Axis axis) const;
+// int32_t get<Axis>() const;
+//
+// // Setters that return false if set was not successful.
+// bool set(int index, int32_t v);
+// bool set(Axis axis, int32_t v);
+// bool set<Axis>(int32_t v);
+//
+// // Returns shape's layout.
+// static const Layout layout;
+//
+// // Turns specific shape into generic shape.
+// Shape ToShape() const;
+//
+// // Copies all dimensions from the given shape.
+// bool Adopt(const Shape&);
+//
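+// A short hedged usage sketch of the StrongShape API defined below (dimension
+// values are illustrative):
+//
+//   BHWC shape(1, 224, 224, 3);        // b, h, w, c in Layout::BHWC order
+//   shape.get<Axis::WIDTH>();          // -> 224
+//   shape.DimensionsProduct();         // -> 1 * 224 * 224 * 3 = 150528
+//   Shape generic = shape.ToShape();   // layout-erased generic shape
+//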
+template <Layout L> struct StrongShape;
+
+using Scalar = StrongShape<Layout::SCALAR>;
+using Linear = StrongShape<Layout::LINEAR>;
+using HW = StrongShape<Layout::HW>;
+using HWD = StrongShape<Layout::HWD>;
+
+// Common tensor shape for CNN models working with images.
+using CHW = StrongShape<Layout::CHW>;
+using HWC = StrongShape<Layout::HWC>;
+using HWDC = StrongShape<Layout::HWDC>;
+using BHWC = StrongShape<Layout::BHWC>;
+using BHWDC = StrongShape<Layout::BHWDC>;
+
+// Tensor shape used in convolution_2d weights.
+using OIHW = StrongShape<Layout::OIHW>;
+using OHWI = StrongShape<Layout::OHWI>;
+using IHWO = StrongShape<Layout::IHWO>;
+using IOHW = StrongShape<Layout::IOHW>;
+
+// Tensor shape used in convolution_3d weights.
+using OHWDI = StrongShape<Layout::OHWDI>;
+
+// -----------------------------------------------------------------------------
+// Everything below are internal implementation details.
+// -----------------------------------------------------------------------------
+
+namespace internal_shape
+{
+
+template <Axis T> struct AxisTraits;
+
+#define TFLITE_GPU_AXIS_TRAITS(AxisName, HolderName) \
+ template <> struct AxisTraits<Axis::AxisName> \
+ { \
+ struct Holder \
+ { \
+ int32_t HolderName; \
+ \
+ protected: \
+ int32_t operator()() const { return HolderName; } \
+ void operator()(int32_t v) { HolderName = v; } \
+ }; \
+ \
+ using dimension_holder_type = Holder; \
+ }
+
+TFLITE_GPU_AXIS_TRAITS(CHANNELS, c);
+TFLITE_GPU_AXIS_TRAITS(HEIGHT, h);
+TFLITE_GPU_AXIS_TRAITS(WIDTH, w);
+TFLITE_GPU_AXIS_TRAITS(INPUT_CHANNELS, i);
+TFLITE_GPU_AXIS_TRAITS(OUTPUT_CHANNELS, o);
+TFLITE_GPU_AXIS_TRAITS(BATCH, b);
+TFLITE_GPU_AXIS_TRAITS(VALUE, v);
+TFLITE_GPU_AXIS_TRAITS(DEPTH, d);
+
+#undef TFLITE_GPU_AXIS_TRAITS
+
+template <int N, Axis... As> struct StrongShapeImpl;
+
+template <int N> struct StrongShapeImpl<N>
+{
+ static constexpr int size() { return N; }
+
+ static constexpr Axis axis(int) { return Axis::UNKNOWN; }
+
+ static constexpr int index(Axis) { return -1; }
+
+ static constexpr bool has(Axis) { return false; }
+
+ int32_t get(Axis) const { return -1; }
+
+ int32_t get(int) const { return -1; }
+
+ template <Axis B> int32_t get() const { return -1; }
+
+ bool set(Axis, int32_t) { return false; }
+
+ bool set(int, int32_t) { return false; }
+
+ template <Axis B> bool set(int32_t) { return false; }
+};
+
+// Used to deduce the number of axes, and to be a child of a proper holder to
+// provide access to the dimension by name.
+template <int N, Axis A, Axis... As>
+struct StrongShapeImpl<N, A, As...> : public AxisTraits<A>::dimension_holder_type,
+ public StrongShapeImpl<N + 1, As...>
+{
+ using dimension_holder_type = typename AxisTraits<A>::dimension_holder_type;
+
+ using rest_type = StrongShapeImpl<N + 1, As...>;
+
+ StrongShapeImpl() : dimension_holder_type{0}, rest_type() {}
+
+ template <typename... Ts>
+ explicit StrongShapeImpl(int32_t t, Ts... ts) : dimension_holder_type{t}, rest_type(ts...)
+ {
+ }
+
+ static constexpr Axis axis(int index) { return index == N ? A : rest_type::axis(index); }
+
+ static constexpr int index(Axis axis) { return axis == A ? N : rest_type::index(axis); }
+
+ static constexpr bool has(Axis axis) { return axis == A ? true : rest_type::has(axis); }
+
+ int32_t get(Axis axis) const
+ {
+ return axis == A ? dimension_holder_type::operator()() : rest_type::get(axis);
+ }
+
+ template <Axis B> int32_t get() const
+ {
+ return B == A ? dimension_holder_type::operator()() : rest_type::template get<B>();
+ }
+
+ int32_t get(int index) const
+ {
+ return index == N ? dimension_holder_type::operator()() : rest_type::get(index);
+ }
+
+ bool set(Axis axis, int32_t t)
+ {
+ if (axis == A)
+ {
+ dimension_holder_type::operator()(t);
+ return true;
+ }
+ return rest_type::set(axis, t);
+ }
+
+ bool set(int index, int32_t t)
+ {
+ if (index == N)
+ {
+ dimension_holder_type::operator()(t);
+ return true;
+ }
+ return rest_type::set(index, t);
+ }
+
+ template <Axis B> bool set(int32_t t)
+ {
+ if (A == B)
+ {
+ dimension_holder_type::operator()(t);
+ return true;
+ }
+ return rest_type::template set<B>(t);
+ }
+};
+
+template <Layout T> struct LayoutTraits;
+
+#define TFLITE_GPU_LAYOUT_TRAITS(LayoutName, ...) \
+ template <> struct LayoutTraits<Layout::LayoutName> \
+ { \
+ using strong_shape_type = StrongShapeImpl<0, __VA_ARGS__>; \
+ }
+
+TFLITE_GPU_LAYOUT_TRAITS(HW, Axis::HEIGHT, Axis::WIDTH);
+TFLITE_GPU_LAYOUT_TRAITS(HWD, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH);
+TFLITE_GPU_LAYOUT_TRAITS(OHWI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
+ Axis::INPUT_CHANNELS);
+TFLITE_GPU_LAYOUT_TRAITS(OIHW, Axis::OUTPUT_CHANNELS, Axis::INPUT_CHANNELS, Axis::HEIGHT,
+ Axis::WIDTH);
+TFLITE_GPU_LAYOUT_TRAITS(IOHW, Axis::INPUT_CHANNELS, Axis::OUTPUT_CHANNELS, Axis::HEIGHT,
+ Axis::WIDTH);
+TFLITE_GPU_LAYOUT_TRAITS(IHWO, Axis::INPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH,
+ Axis::OUTPUT_CHANNELS);
+TFLITE_GPU_LAYOUT_TRAITS(CHW, Axis::CHANNELS, Axis::HEIGHT, Axis::WIDTH);
+TFLITE_GPU_LAYOUT_TRAITS(HWC, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS);
+TFLITE_GPU_LAYOUT_TRAITS(HWDC, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH, Axis::CHANNELS);
+TFLITE_GPU_LAYOUT_TRAITS(LINEAR, Axis::VALUE);
+TFLITE_GPU_LAYOUT_TRAITS(SCALAR, Axis::VALUE);
+TFLITE_GPU_LAYOUT_TRAITS(BHWC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS);
+TFLITE_GPU_LAYOUT_TRAITS(BHWDC, Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH,
+ Axis::CHANNELS);
+TFLITE_GPU_LAYOUT_TRAITS(OHWDI, Axis::OUTPUT_CHANNELS, Axis::HEIGHT, Axis::WIDTH, Axis::DEPTH,
+ Axis::INPUT_CHANNELS);
+
+#undef TFLITE_GPU_LAYOUT_TRAITS
+
+template <> struct LayoutTraits<Layout::UNKNOWN>
+{
+ using strong_shape_type = StrongShapeImpl<0>;
+};
+
+template <Axis A> struct DimensionGetterFixedAxisFunc
+{
+ template <Layout T> int32_t operator()() const
+ {
+ constexpr int i = GetAxisIndex<T>(A);
+ return i >= 0 && i < l->dimensions.size() ? l->dimensions[i] : -1;
+ }
+ const Shape *l;
+};
+
+struct DimensionGetterFunc
+{
+ template <Layout T> int32_t operator()() const
+ {
+ uint32_t i = GetAxisIndex<T>(axis);
+ return i < l->dimensions.size() ? l->dimensions[i] : -1;
+ }
+ Axis axis;
+ const Shape *l;
+};
+
+template <Axis A> struct DimensionSetterFixedAxisFunc
+{
+ template <Layout T> bool operator()() const
+ {
+ constexpr uint32_t i = GetAxisIndex<T>(A);
+ if (i < l->dimensions.size())
+ {
+ l->dimensions[i] = v;
+ return true;
+ }
+ return false;
+ }
+ Shape *l;
+ int32_t v;
+};
+
+struct DimensionSetterFunc
+{
+ template <Layout T> bool operator()() const
+ {
+ uint32_t i = GetAxisIndex<T>(axis);
+ if (i < l->dimensions.size())
+ {
+ l->dimensions[i] = v;
+ return true;
+ }
+ return false;
+ }
+ Axis axis;
+ Shape *l;
+ int32_t v;
+};
+
+template <Layout L> struct ToShapeFunc
+{
+ template <Layout T> bool operator()() const
+ {
+ for (int i = 0; i < StrongShape<L>::size(); ++i)
+ {
+ int index = GetAxisIndex<T>(StrongShape<L>::axis(i));
+ if (index < 0)
+ return false;
+ shape->set(i, l.dimensions[index]);
+ }
+ return true;
+ }
+
+ StrongShape<L> *shape;
+ const Shape &l;
+};
+
+} // namespace internal_shape
+
+// template <Axis... As>
+template <Layout L> struct StrongShape : public internal_shape::LayoutTraits<L>::strong_shape_type
+{
+ using strong_shape_type = typename internal_shape::LayoutTraits<L>::strong_shape_type;
+ StrongShape() = default;
+
+ template <typename... Ts> explicit StrongShape(Ts... t) : strong_shape_type(t...) {}
+
+ constexpr static Layout layout = L;
+
+ bool operator==(const StrongShape<L> &shape) const
+ {
+ // TODO(akulik): implement better alternative.
+ return this->ToShape() == shape.ToShape();
+ }
+
+ bool operator!=(const StrongShape<L> &shape) const
+ {
+ // TODO(akulik): implement better alternative.
+ return this->ToShape() != shape.ToShape();
+ }
+ bool empty() const { return DimensionsProduct() == 0; }
+
+ // Turns StrongShape into generic shape.
+ Shape ToShape() const
+ {
+ std::vector<int32_t> dimensions(StrongShape::size());
+ for (int i = 0; i < StrongShape::size(); ++i)
+ {
+ dimensions[i] = StrongShape::get(i);
+ }
+ return Shape(L, std::move(dimensions));
+ }
+
+ // @return all dimensions multiplied
+ int64_t DimensionsProduct() const
+ {
+ int64_t product = 1;
+ for (int i = 0; i < StrongShape::size(); ++i)
+ {
+ product *= StrongShape::get(i);
+ }
+ return product;
+ }
+
+  // Translates the given coordinates of the layout into a linear index, assuming
+  // dimensions are sorted in tensor-access order, e.g. if you access
+  // foobar[i][j][k], the order of coordinates should be i, j, k.
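+  // Hedged worked example (illustrative dimensions): for an HWC shape with
+  // h = 2, w = 3, c = 4, LinearIndex({1, 2, 3}) = (1 * 3 + 2) * 4 + 3 = 23.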
+ int64_t LinearIndex(const std::array<int32_t, StrongShape::size()> &coordinates) const
+ {
+ int64_t index = coordinates[0];
+ for (int i = 1; i < StrongShape::size(); ++i)
+ {
+ index = index * StrongShape::get(i) + coordinates[i];
+ }
+ return index;
+ }
+
+  // Copies all dimensions from the given generic shape into this specific
+  // shape. It requires the given shape to define every axis present in this
+  // StrongShape. For example:
+  //   - If this shape is OHWI and the given shape is OIHW, Adopt will copy all
+  //     dimensions and return true.
+  //   - If this shape is HW and the given shape is OIHW, Adopt will copy H and
+  //     W and return true; but if this shape is OIHW and the given shape is
+  //     HW, Adopt will return false because the O and I axes are not present
+  //     in the given shape.
+  //
+  // @return false if the generic shape is not compatible.
+ bool Adopt(const Shape &shape)
+ {
+ return DispatchByLayout(shape.layout, internal_shape::ToShapeFunc<L>{this, shape});
+ }
+
+ // For all axis defined in a given shape copies values to this shape.
+ // Therefore, it is possible to copy dimensions from CHW to BCHW, but not
+ // the other way around.
+ //
+ // BCHW bchw;
+ // CHW chw;
+ // bchw.CopyAllGivenAxis(chw); --> true
+ // chw.CopyAllGivenAxis(bchw); --> false
+ //
+ // @return false if axis in source shape is not defined here, thus value
+ // was not copied.
+ template <Layout B> bool CopyAllGivenAxis(const StrongShape<B> &source)
+ {
+ for (int i = 0; i < source.size(); ++i)
+ {
+ if (!StrongShape::set(source.axis(i), source.get(i)))
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // For all axis defined in this shape copies values from the given shape.
+ //
+ // BCHW bchw;
+ // CHW chw;
+ // bchw.CopyAllDefinedAxis(chw); --> false
+ // chw.CopyAllDefinedAxis(bchw); --> true
+ //
+ // @return false if given shape does not have axis defined here,
+ // therefore a value was not copied.
+ template <Layout B> bool CopyAllDefinedAxis(const StrongShape<B> &source)
+ {
+ for (int i = 0; i < StrongShape::size(); ++i)
+ {
+ int source_index = source.index(StrongShape::axis(i));
+ if (source_index < 0)
+ {
+ return false;
+ }
+ StrongShape::set(i, source.get(source_index)); // always true
+ }
+ return true;
+ }
+
+ // Copies values only for matching axis.
+ template <Layout B> void CopyMatchingAxis(const StrongShape<B> &source)
+ {
+ for (int i = 0; i < StrongShape::size(); ++i)
+ {
+ StrongShape::set(source.axis(i), source.get(i));
+ }
+ }
+
+ // AbslHash function for using in flat hash containers.
+ template <typename H> friend H AbslHashValue(H hash_state, const StrongShape &strong_shape)
+ {
+ for (size_t i = 0; i < strong_shape.size(); ++i)
+ {
+ hash_state = H::combine(std::move(hash_state), strong_shape.get(i));
+ }
+ return hash_state;
+ }
+};
+
+template <Layout T> inline std::string ToString(const StrongShape<T> &s)
+{
+ return ToString(s.ToShape());
+}
+
+template <Layout L> constexpr Layout StrongShape<L>::layout;
+
+template <class F>
+auto DispatchByLayout(Layout type, F f) -> decltype(f.template operator()<Layout::UNKNOWN>())
+{
+ switch (type)
+ {
+ case Layout::HW:
+ return f.template operator()<Layout::HW>();
+ case Layout::HWD:
+ return f.template operator()<Layout::HWD>();
+ case Layout::HWC:
+ return f.template operator()<Layout::HWC>();
+ case Layout::HWDC:
+ return f.template operator()<Layout::HWDC>();
+ case Layout::CHW:
+ return f.template operator()<Layout::CHW>();
+ case Layout::OIHW:
+ return f.template operator()<Layout::OIHW>();
+ case Layout::IOHW:
+ return f.template operator()<Layout::IOHW>();
+ case Layout::OHWI:
+ return f.template operator()<Layout::OHWI>();
+ case Layout::IHWO:
+ return f.template operator()<Layout::IHWO>();
+ case Layout::LINEAR:
+ return f.template operator()<Layout::LINEAR>();
+ case Layout::SCALAR:
+ return f.template operator()<Layout::SCALAR>();
+ case Layout::BHWC:
+ return f.template operator()<Layout::BHWC>();
+ case Layout::BHWDC:
+ return f.template operator()<Layout::BHWDC>();
+ case Layout::OHWDI:
+ return f.template operator()<Layout::OHWDI>();
+ case Layout::UNKNOWN:
+ return f.template operator()<Layout::UNKNOWN>();
+ }
+ return f.template operator()<Layout::UNKNOWN>();
+}
+
+template <Layout T> constexpr int Size() { return StrongShape<T>::size(); }
+
+template <Layout T> constexpr Axis GetAxis(int index) { return StrongShape<T>::axis(index); }
+
+template <Layout T> constexpr int GetAxisIndex(Axis axis) { return StrongShape<T>::index(axis); }
+
+template <Layout T> constexpr bool HasAxis(Axis axis) { return StrongShape<T>::has(axis); }
+
+template <Axis D> inline int32_t Shape::get() const
+{
+ return DispatchByLayout(layout, internal_shape::DimensionGetterFixedAxisFunc<D>{this});
+}
+
+inline int32_t Shape::get(Axis axis) const
+{
+ return DispatchByLayout(layout, internal_shape::DimensionGetterFunc{axis, this});
+}
+
+template <Axis D> inline bool Shape::set(int32_t t)
+{
+ return DispatchByLayout(layout, internal_shape::DimensionSetterFixedAxisFunc<D>{this, t});
+}
+
+inline bool Shape::set(Axis axis, int32_t t)
+{
+ return DispatchByLayout(layout, internal_shape::DimensionSetterFunc{axis, this, t});
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_SHAPE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPEN_CL_SPI_H__
+#define __ONERT_BACKEND_GPU_CL_OPEN_CL_SPI_H__
+
+#include <cstdint>
+
+#include "Api.h"
+#include "AccessType.h"
+#include "Status.h"
+
+// Contains only service provider-related interfaces. Users should not use them
+// directly.
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// Converts a tensor object into another one.
+class TensorObjectConverter
+{
+public:
+ virtual ~TensorObjectConverter() = default;
+
+ virtual absl::Status Convert(const TensorObject &input, const TensorObject &output) = 0;
+};
+
+class TensorObjectConverterBuilder
+{
+public:
+ virtual ~TensorObjectConverterBuilder() = default;
+
+ virtual bool IsSupported(const TensorObjectDef &input, const TensorObjectDef &output) const = 0;
+
+ virtual absl::Status MakeConverter(const TensorObjectDef &input, const TensorObjectDef &output,
+ std::unique_ptr<TensorObjectConverter> *converter) = 0;
+};
+
+// Connects tensor definition provided by a user (external) with tensor
+// definition used by the inference engine (internal).
+struct TensorTieDef
+{
+ uint32_t id;
+ AccessType access_type;
+ TensorObjectDef internal_def;
+ TensorObjectDef external_def;
+};
+
+// Connects external tensor object to internal tensor object and provides
+// functionality to copy data to/from external object to internal.
+class TensorTie
+{
+public:
+ explicit TensorTie(const TensorTieDef &def) : def_(def) {}
+
+ virtual ~TensorTie() = default;
+
+ virtual absl::Status SetExternalObject(TensorObject obj) = 0;
+
+ virtual TensorObject GetExternalObject() = 0;
+
+ virtual absl::Status CopyToExternalObject() = 0;
+
+ virtual absl::Status CopyFromExternalObject() = 0;
+
+ const TensorTieDef &def() const { return def_; }
+
+private:
+ const TensorTieDef def_;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPEN_CL_SPI_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_STATUS_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_STATUS_H__
+
+#include "absl/status/status.h" // IWYU pragma: export
+#define RETURN_IF_ERROR(s) \
+ { \
+ auto c = (s); \
+ if (!c.ok()) \
+ return c; \
+ } // IWYU pragma: export
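+
+// Hedged usage sketch (function names are illustrative):
+//   absl::Status DoWork()
+//   {
+//     RETURN_IF_ERROR(StepOne()); // returns early on a non-OK status
+//     RETURN_IF_ERROR(StepTwo());
+//     return absl::OkStatus();
+//   }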
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_STATUS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "StorageTypeUtil.h"
+
+#include "TensorType.h"
+#include "DataType.h"
+#include "Shape.h"
+#include "Util.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+bool CanCreateTensorWithShape(const DeviceInfo &device_info, const BHWDC &shape,
+ const TensorDescriptor &descriptor)
+{
+ const int slices = DivideRoundUp(shape.c, 4);
+ switch (descriptor.storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ {
+ const uint64_t flt4_size = 4 * (descriptor.data_type == DataType::FLOAT32 ? 4 : 2);
+ const uint64_t buffer_size = shape.b * shape.w * shape.h * shape.d * slices * flt4_size;
+ return buffer_size <= device_info.buffer_max_size;
+ }
+ case TensorStorageType::IMAGE_BUFFER:
+ return (uint64_t)shape.b * shape.w * shape.h * shape.d * slices <=
+ device_info.image_buffer_max_size;
+ case TensorStorageType::TEXTURE_3D:
+ if (device_info.cl_version < OpenCLVersion::CL_1_2 && slices == 1)
+ {
+        // clCreateImage3D (used in CL 1.0/1.1) cannot create an image with
+        // depth = 1, per the specification.
+ return false;
+ }
+ return (uint64_t)shape.w * shape.b <= device_info.image3d_max_width &&
+ (uint64_t)shape.h <= device_info.image3d_max_height &&
+ (uint64_t)slices * shape.d <= device_info.image3d_max_depth;
+ case TensorStorageType::TEXTURE_ARRAY:
+ // Bug on some Adreno. b/131099086
+ if (slices == 1 && !device_info.SupportsOneLayerTextureArray())
+ {
+ return false;
+ }
+ return (uint64_t)shape.w * shape.b <= device_info.image2d_max_width &&
+ (uint64_t)shape.h <= device_info.image2d_max_height &&
+ (uint64_t)slices * shape.d <= device_info.image_array_max_layers;
+ case TensorStorageType::TEXTURE_2D:
+ return (uint64_t)shape.w * shape.b * shape.d <= device_info.image2d_max_width &&
+ (uint64_t)shape.h * slices <= device_info.image2d_max_height;
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return (uint64_t)shape.c <= 4 &&
+ device_info.SupportsFloatImage2D(descriptor.data_type, shape.c) &&
+ (uint64_t)shape.w * shape.b * shape.d <= device_info.image2d_max_width &&
+ (uint64_t)shape.h <= device_info.image2d_max_height;
+ default:
+ return false;
+ }
+}
+
+bool CanCreateTensorWithShape(const DeviceInfo &device_info, const BHWC &shape,
+ const TensorDescriptor &descriptor)
+{
+ const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
+ return CanCreateTensorWithShape(device_info, shape5D, descriptor);
+}
+
+TensorStorageType SelectBestStorageType(const DeviceInfo &device_info, const BHWC &shape,
+ const TensorStorageType &desired, const DataType &data_type,
+ const Layout &layout)
+{
+ if (CanCreateTensorWithShape(device_info, shape, TensorDescriptor{data_type, desired, layout}))
+ {
+ return desired;
+ }
+ auto GetBestTypeAfterTextureArray = [&]() {
+ if (device_info.SupportsImageBuffer() &&
+ CanCreateTensorWithShape(
+ device_info, shape, TensorDescriptor{data_type, TensorStorageType::IMAGE_BUFFER, layout}))
+ {
+ return TensorStorageType::IMAGE_BUFFER;
+ }
+ else
+ {
+ return TensorStorageType::BUFFER;
+ }
+ };
+ auto GetBestTypeAfterTexture2D = [&]() {
+ if (device_info.SupportsTextureArray() &&
+ CanCreateTensorWithShape(
+ device_info, shape,
+ TensorDescriptor{data_type, TensorStorageType::TEXTURE_ARRAY, layout}))
+ {
+ return TensorStorageType::TEXTURE_ARRAY;
+ }
+ else
+ {
+ return GetBestTypeAfterTextureArray();
+ }
+ };
+ auto GetBestTypeAfterTexture3D = [&]() {
+ if (CanCreateTensorWithShape(
+ device_info, shape, TensorDescriptor{data_type, TensorStorageType::TEXTURE_2D, layout}))
+ {
+ return TensorStorageType::TEXTURE_2D;
+ }
+ else
+ {
+ return GetBestTypeAfterTexture2D();
+ }
+ };
+ switch (desired)
+ {
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return GetBestTypeAfterTexture2D();
+ case TensorStorageType::TEXTURE_ARRAY:
+ return GetBestTypeAfterTextureArray();
+ case TensorStorageType::TEXTURE_3D:
+ return GetBestTypeAfterTexture3D();
+ case TensorStorageType::IMAGE_BUFFER:
+ case TensorStorageType::BUFFER:
+ return TensorStorageType::BUFFER;
+ default:
+ return TensorStorageType::BUFFER;
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_STORAGE_TYPE_UTIL_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_STORAGE_TYPE_UTIL_H__
+
+#include "DeviceInfo.h"
+#include "TensorType.h"
+#include "DataType.h"
+#include "Shape.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+bool CanCreateTensorWithShape(const DeviceInfo &device_info, const BHWDC &shape,
+ const TensorDescriptor &descriptor);
+
+bool CanCreateTensorWithShape(const DeviceInfo &device_info, const BHWC &shape,
+ const TensorDescriptor &descriptor);
+
+TensorStorageType SelectBestStorageType(const DeviceInfo &device_info, const BHWC &shape,
+ const TensorStorageType &desired, const DataType &data_type,
+ const Layout &layout);
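+
+// Hedged usage sketch (shape and types are illustrative). Per the fallback
+// chain in StorageTypeUtil.cc, when the desired type cannot hold the shape on
+// this device the result degrades toward TEXTURE_ARRAY, then IMAGE_BUFFER, and
+// finally BUFFER:
+//   TensorStorageType st =
+//     SelectBestStorageType(device_info, BHWC(1, 128, 128, 32),
+//                           TensorStorageType::TEXTURE_2D, DataType::FLOAT16,
+//                           Layout::BHWC);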
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_STORAGE_TYPE_UTIL_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Tensor.h"
+
+#include <cstring>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+
+#include "Buffer.h"
+#include "ClImageFormat.h"
+#include "ClMemory.h"
+#include "GpuObject.h"
+#include "TensorType.h"
+#include "InternalTensor.h"
+#include "DataType.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+absl::Status AllocateTensorMemory(const CLContext &context, const BHWDC &shape,
+ const TensorDescriptor &descriptor, const void *data_ptr,
+ CLMemory *result)
+{
+ const int slices = DivideRoundUp(shape.c, 4);
+ cl_mem_flags mem_flags = CL_MEM_READ_WRITE;
+ if (data_ptr)
+ {
+ mem_flags |= CL_MEM_COPY_HOST_PTR;
+ }
+ switch (descriptor.storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ {
+ const size_t data_size =
+ shape.b * shape.w * shape.h * shape.d * slices * 4 * SizeOf(descriptor.data_type);
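+      // Hedged worked example (illustrative values): for FLOAT16 data with
+      // b=1, h=8, w=8, d=1, c=6 (so slices = 2),
+      // data_size = 1 * 8 * 8 * 1 * 2 * 4 * 2 = 1024 bytes.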
+ cl_int error_code;
+ cl_mem memory = clCreateBuffer(context.context(), mem_flags, data_size,
+ const_cast<void *>(data_ptr), &error_code);
+ if (!memory)
+ {
+ return absl::UnknownError(absl::StrCat(
+ "Failed to allocate device memory (clCreateBuffer): ", CLErrorCodeToString(error_code)));
+ }
+ *result = CLMemory(memory, true);
+ return absl::OkStatus();
+ }
+ case TensorStorageType::TEXTURE_2D:
+ {
+ cl_image_desc desc;
+ desc.image_type = CL_MEM_OBJECT_IMAGE2D;
+ desc.image_width = shape.w * shape.b * shape.d;
+ desc.image_height = shape.h * slices;
+ desc.image_depth = 0;
+ desc.image_row_pitch = 0;
+ desc.image_slice_pitch = 0;
+ desc.num_mip_levels = 0;
+ desc.num_samples = 0;
+ desc.buffer = nullptr;
+
+ cl_image_format format;
+ format.image_channel_order = CL_RGBA;
+ format.image_channel_data_type = ToImageChannelType(descriptor.data_type);
+
+ cl_int error_code;
+ cl_mem memory = CreateImage2DLegacy(context.context(), mem_flags, &format, &desc,
+ const_cast<void *>(data_ptr), &error_code);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to create 2D texture (clCreateImage): ",
+ CLErrorCodeToString(error_code)));
+ }
+
+ *result = CLMemory(memory, true);
+ return absl::OkStatus();
+ }
+ case TensorStorageType::TEXTURE_3D:
+ {
+ cl_image_desc desc;
+ desc.image_type = CL_MEM_OBJECT_IMAGE3D;
+ desc.image_width = shape.w * shape.b;
+ desc.image_height = shape.h;
+ desc.image_depth = slices * shape.d;
+ desc.image_row_pitch = 0;
+ desc.image_slice_pitch = 0;
+ desc.num_mip_levels = 0;
+ desc.num_samples = 0;
+ desc.buffer = nullptr;
+
+ cl_image_format format;
+ format.image_channel_order = CL_RGBA;
+ format.image_channel_data_type = ToImageChannelType(descriptor.data_type);
+
+ cl_int error_code;
+ cl_mem memory = CreateImage3DLegacy(context.context(), mem_flags, &format, &desc,
+ const_cast<void *>(data_ptr), &error_code);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to create 3D texture (clCreateImage): ",
+ CLErrorCodeToString(error_code)));
+ }
+
+ *result = CLMemory(memory, true);
+ return absl::OkStatus();
+ }
+ case TensorStorageType::TEXTURE_ARRAY:
+ {
+ cl_image_desc desc;
+ desc.image_type = CL_MEM_OBJECT_IMAGE2D_ARRAY;
+ desc.image_width = shape.w * shape.b;
+ desc.image_height = shape.h;
+ desc.image_depth = 0;
+ desc.image_array_size = slices * shape.d;
+ desc.image_row_pitch = 0;
+ desc.image_slice_pitch = 0;
+ desc.num_mip_levels = 0;
+ desc.num_samples = 0;
+ desc.buffer = nullptr;
+
+ cl_image_format format;
+ format.image_channel_order = CL_RGBA;
+ format.image_channel_data_type = ToImageChannelType(descriptor.data_type);
+
+ cl_int error_code;
+ cl_mem memory = clCreateImage(context.context(), mem_flags, &format, &desc,
+ const_cast<void *>(data_ptr), &error_code);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat(
+ "Failed to create 2D texture array (clCreateImage): ", CLErrorCodeToString(error_code)));
+ }
+
+ *result = CLMemory(memory, true);
+ return absl::OkStatus();
+ }
+
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ {
+ if (slices != 1)
+ {
+ return absl::InvalidArgumentError(absl::StrCat(
+          "SINGLE_TEXTURE_2D supports only channels in range [1-4], but ", shape.c, " was provided"));
+ }
+ cl_image_desc desc;
+ desc.image_type = CL_MEM_OBJECT_IMAGE2D;
+ desc.image_width = shape.w * shape.b * shape.d;
+ desc.image_height = shape.h;
+ desc.image_depth = 0;
+ desc.image_row_pitch = 0;
+ desc.image_slice_pitch = 0;
+ desc.num_mip_levels = 0;
+ desc.num_samples = 0;
+ desc.buffer = nullptr;
+
+ cl_image_format format;
+ if (context.IsFloatTexture2DSupported(shape.c, descriptor.data_type))
+ {
+ format.image_channel_order = ToChannelOrder(shape.c);
+ format.image_channel_data_type = ToImageChannelType(descriptor.data_type);
+ }
+ else
+ {
+ return absl::InvalidArgumentError(
+ absl::StrCat("This device doesn't support ", shape.c, "-channel textures."));
+ }
+
+ cl_int error_code;
+ cl_mem memory = CreateImage2DLegacy(context.context(), mem_flags, &format, &desc,
+ const_cast<void *>(data_ptr), &error_code);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat(
+ "Failed to create single 2D texture (clCreateImage): ", CLErrorCodeToString(error_code)));
+ }
+
+ *result = CLMemory(memory, true);
+ return absl::OkStatus();
+ }
+
+ default:
+ return absl::InternalError("Unsupported tensor storage type");
+ }
+}
+
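+// Wraps an existing cl_mem buffer in a 1D image view (CL_MEM_OBJECT_IMAGE1D_BUFFER)
+// so that IMAGE_BUFFER tensors can be sampled with read_image* in kernels.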
+absl::Status CreateImageBufferFromBuffer(const CLContext &context, cl_mem memory,
+ DataType data_type, int width, cl_mem *result)
+{
+ cl_image_format format;
+ cl_image_desc desc;
+ std::memset(&desc, 0, sizeof(desc));
+ desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+ desc.image_width = width;
+ desc.mem_object = memory;
+
+ format.image_channel_data_type = ToImageChannelType(data_type);
+ format.image_channel_order = CL_RGBA;
+
+ cl_int error_code;
+ *result =
+ clCreateImage(context.context(), CL_MEM_READ_WRITE, &format, &desc, nullptr, &error_code);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to create Image from Buffer (clCreateImage): ",
+ CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status CreateTensor(const CLContext &context, const BHWDC &shape,
+ const TensorDescriptor &descriptor, cl_mem memory, Tensor *result)
+{
+ const bool memory_owner = memory == nullptr;
+ if (memory_owner)
+ {
+ CLMemory mem;
+ RETURN_IF_ERROR(AllocateTensorMemory(context, shape, descriptor, nullptr, &mem));
+ memory = mem.Release();
+ }
+ if (descriptor.storage_type == TensorStorageType::IMAGE_BUFFER)
+ {
+ cl_mem image_memory;
+ RETURN_IF_ERROR(CreateImageBufferFromBuffer(
+ context, memory, descriptor.data_type,
+ shape.b * shape.w * shape.h * shape.d * DivideRoundUp(shape.c, 4), &image_memory));
+ *result = Tensor(memory, memory_owner, image_memory, shape, descriptor);
+ }
+ else
+ {
+ *result = Tensor(memory, memory_owner, shape, descriptor);
+ }
+ return absl::OkStatus();
+}
+
+absl::Status CreateTensorShared(const CLContext &context, const BHWDC &shape,
+ const TensorDescriptor &descriptor, cl_mem memory, Tensor *result)
+{
+ const bool memory_owner = false;
+ if (descriptor.storage_type == TensorStorageType::IMAGE_BUFFER)
+ {
+ cl_mem image_memory;
+ RETURN_IF_ERROR(CreateImageBufferFromBuffer(
+ context, memory, descriptor.data_type,
+ shape.b * shape.w * shape.h * shape.d * DivideRoundUp(shape.c, 4), &image_memory));
+ *result = Tensor(memory, memory_owner, image_memory, shape, descriptor);
+ }
+ else
+ {
+ *result = Tensor(memory, memory_owner, shape, descriptor);
+ }
+ return absl::OkStatus();
+}
+
+} // namespace
+
+absl::Status TensorDescriptor::CreateGPUObject(CLContext *context, GPUObjectPtr *result) const
+{
+ Tensor gpu_tensor;
+ RETURN_IF_ERROR(gpu_tensor.CreateFromDescriptor(*this, context));
+ *result = absl::make_unique<Tensor>(std::move(gpu_tensor));
+ return absl::OkStatus();
+}
+
+Tensor::Tensor(cl_mem memory, bool memory_owner, const BHWC &shape,
+ const TensorDescriptor &descriptor)
+ : memory_(memory), image_buffer_memory_(nullptr), memory_owner_(memory_owner),
+ shape_(shape.b, shape.h, shape.w, 1, shape.c), descriptor_(descriptor)
+{
+}
+
+Tensor::Tensor(cl_mem memory, bool memory_owner, const BHWDC &shape,
+ const TensorDescriptor &descriptor)
+ : memory_(memory), image_buffer_memory_(nullptr), memory_owner_(memory_owner), shape_(shape),
+ descriptor_(descriptor)
+{
+}
+
+Tensor::Tensor(cl_mem memory, bool memory_owner, cl_mem image_buffer_memory, const BHWC &shape,
+ const TensorDescriptor &descriptor)
+ : memory_(memory), image_buffer_memory_(image_buffer_memory), memory_owner_(memory_owner),
+ shape_(shape.b, shape.h, shape.w, 1, shape.c), descriptor_(descriptor)
+{
+}
+
+Tensor::Tensor(cl_mem memory, bool memory_owner, cl_mem image_buffer_memory, const BHWDC &shape,
+ const TensorDescriptor &descriptor)
+ : memory_(memory), image_buffer_memory_(image_buffer_memory), memory_owner_(memory_owner),
+ shape_(shape), descriptor_(descriptor)
+{
+}
+
+Tensor::Tensor(Tensor &&tensor)
+ : memory_(tensor.memory_), image_buffer_memory_(tensor.image_buffer_memory_),
+ memory_owner_(tensor.memory_owner_), shape_(tensor.shape_), descriptor_(tensor.descriptor_)
+{
+ tensor.memory_ = nullptr;
+ tensor.image_buffer_memory_ = nullptr;
+}
+
+Tensor &Tensor::operator=(Tensor &&tensor)
+{
+ if (this != &tensor)
+ {
+ Release();
+ std::swap(memory_, tensor.memory_);
+ std::swap(image_buffer_memory_, tensor.image_buffer_memory_);
+ std::swap(memory_owner_, tensor.memory_owner_);
+ std::swap(shape_, tensor.shape_);
+ std::swap(descriptor_, tensor.descriptor_);
+ }
+ return *this;
+}
+
+void Tensor::Release()
+{
+  // image_buffer_memory_ is always owned by this object
+ if (image_buffer_memory_)
+ {
+ clReleaseMemObject(image_buffer_memory_);
+ image_buffer_memory_ = nullptr;
+ }
+ if (memory_owner_ && memory_)
+ {
+ clReleaseMemObject(memory_);
+ memory_ = nullptr;
+ }
+}
+
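+// Binds this tensor's dimensions and memory object to the resource names declared
+// by TensorDescriptor::GetGPUResources() (or by a plain BufferDescriptor).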
+absl::Status Tensor::GetGPUResources(const GPUObjectDescriptor *obj_ptr,
+ GPUResourcesWithValue *resources) const
+{
+ const auto *buffer_desc = dynamic_cast<const BufferDescriptor *>(obj_ptr);
+ if (buffer_desc)
+ {
+ if (descriptor_.storage_type != TensorStorageType::BUFFER)
+ {
+ return absl::InvalidArgumentError("Tensor can be used with BufferDescriptor only wtih "
+ "TensorStorageType::BUFFER.");
+ }
+ resources->buffers.push_back({"buffer", memory_});
+ return absl::OkStatus();
+ }
+ const auto *tensor_desc = dynamic_cast<const TensorDescriptor *>(obj_ptr);
+ if (!tensor_desc)
+ {
+ return absl::InvalidArgumentError("Expected TensorDescriptor on input.");
+ }
+ if (descriptor_.HasAxis(Axis::WIDTH))
+ {
+ resources->ints.push_back({"width", Width()});
+ resources->ints.push_back({"width_div2", Width() / 2});
+ resources->ints.push_back({"width_div4", Width() / 4});
+ resources->ints.push_back({"width_batched", Width() * Batch()});
+ resources->ints.push_back({"width_batched_div2", Width() * Batch() / 2});
+ resources->ints.push_back({"width_batched_div4", Width() * Batch() / 4});
+ }
+ if (descriptor_.HasAxis(Axis::HEIGHT))
+ {
+ resources->ints.push_back({"height", Height()});
+ }
+ if (descriptor_.HasAxis(Axis::CHANNELS))
+ {
+ resources->ints.push_back({"slices", Slices()});
+ resources->ints.push_back({"channels", Channels()});
+ }
+ if (descriptor_.HasAxis(Axis::BATCH))
+ {
+ resources->ints.push_back({"batch", Batch()});
+ }
+ if (descriptor_.HasAxis(Axis::DEPTH))
+ {
+ resources->ints.push_back({"depth", Depth()});
+ }
+
+ if (descriptor_.storage_type == TensorStorageType::BUFFER)
+ {
+ resources->buffers.push_back({"buffer", memory_});
+ }
+ else if (descriptor_.storage_type == TensorStorageType::TEXTURE_2D ||
+ descriptor_.storage_type == TensorStorageType::SINGLE_TEXTURE_2D)
+ {
+ resources->images2d.push_back({"image2d", memory_});
+ }
+ else if (descriptor_.storage_type == TensorStorageType::TEXTURE_ARRAY)
+ {
+ resources->image2d_arrays.push_back({"image2d_array", memory_});
+ }
+ else if (descriptor_.storage_type == TensorStorageType::TEXTURE_3D)
+ {
+ resources->images3d.push_back({"image3d", memory_});
+ }
+ else if (descriptor_.storage_type == TensorStorageType::IMAGE_BUFFER)
+ {
+ if (obj_ptr->GetAccess() == AccessType::READ)
+ {
+ resources->image_buffers.push_back({"image_buffer", image_buffer_memory_});
+ }
+ else
+ {
+ resources->buffers.push_back({"buffer", memory_});
+ }
+ }
+
+ return absl::OkStatus();
+}
+
+int3 Tensor::GetFullTensorRegion() const
+{
+ switch (descriptor_.storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_3D:
+ case TensorStorageType::IMAGE_BUFFER:
+ return {shape_.w * shape_.b, shape_.h, shape_.d * Slices()};
+ case TensorStorageType::TEXTURE_2D:
+ return {shape_.w * shape_.b * shape_.d, shape_.h * Slices(), 1};
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return {shape_.w * shape_.b * shape_.d, shape_.h, 1};
+ case TensorStorageType::UNKNOWN:
+ return {-1, -1, -1};
+ }
+ return {-1, -1, -1};
+}
+
+absl::Status Tensor::IsValid(const BHWC &shape) const
+{
+ if (shape.b != shape_.b)
+ {
+ return absl::InvalidArgumentError("Shape batch does not match tensor batch");
+ }
+ if (shape.w != shape_.w)
+ {
+ return absl::InvalidArgumentError("Shape width does not match tensor width");
+ }
+ if (shape.h != shape_.h)
+ {
+ return absl::InvalidArgumentError("Shape height does not match tensor height");
+ }
+ if (shape.c != shape_.c)
+ {
+ return absl::InvalidArgumentError("Shape channels does not match tensor channels");
+ }
+ return absl::OkStatus();
+}
+
+absl::Status Tensor::IsValid(const BHWDC &shape) const
+{
+ if (shape.b != shape_.b)
+ {
+ return absl::InvalidArgumentError("Shape batch does not match tensor batch");
+ }
+ if (shape.w != shape_.w)
+ {
+ return absl::InvalidArgumentError("Shape width does not match tensor width");
+ }
+ if (shape.h != shape_.h)
+ {
+ return absl::InvalidArgumentError("Shape height does not match tensor height");
+ }
+ if (shape.d != shape_.d)
+ {
+ return absl::InvalidArgumentError("Shape depth does not match tensor depth");
+ }
+ if (shape.c != shape_.c)
+ {
+ return absl::InvalidArgumentError("Shape channels does not match tensor channels");
+ }
+ return absl::OkStatus();
+}
+
+int Tensor::GetAlignedChannels() const
+{
+ return descriptor_.storage_type == TensorStorageType::SINGLE_TEXTURE_2D ? shape_.c
+ : AlignByN(shape_.c, 4);
+}
+
+uint64_t Tensor::GetMemorySizeInBytes() const
+{
+ const uint64_t flt_size = static_cast<uint64_t>(SizeOf(descriptor_.data_type));
+ const uint64_t flt4_size = 4 * flt_size;
+ switch (descriptor_.storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::TEXTURE_3D:
+ return flt4_size * shape_.b * shape_.w * shape_.h * shape_.d * Slices();
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return flt_size * shape_.w * shape_.h * shape_.c * shape_.b * shape_.d;
+ default:
+ return 0;
+ }
+}
+
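+// Returns the image view for IMAGE_BUFFER tensors (kernels read through it) and
+// the underlying buffer/image memory for all other storage types.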
+cl_mem Tensor::GetMemoryPtr() const
+{
+ return descriptor_.storage_type == TensorStorageType::IMAGE_BUFFER ? image_buffer_memory_
+ : memory_;
+}
+
+cl_mem Tensor::GetMemoryPtrForWriting() const { return memory_; }
+
+absl::Status Tensor::WriteDataBHWDC(absl::Span<const float> in, CLCommandQueue *queue)
+{
+ void *data_ptr = nullptr;
+ const int aligned_channels = GetAlignedChannels();
+ const int elements_count = shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;
+
+ const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
+ std::vector<float> data_f;
+ data_f.resize(elements_count);
+ data_ptr = data_f.data();
+ DataFromBHWDC(in, shape_, descriptor_, absl::MakeSpan(data_f.data(), data_f.size()));
+
+ switch (descriptor_.storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ RETURN_IF_ERROR(queue->EnqueueWriteBuffer(memory_, data_size, data_ptr));
+ break;
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::TEXTURE_3D:
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ RETURN_IF_ERROR(queue->EnqueueWriteImage(memory_, GetFullTensorRegion(), data_ptr));
+ break;
+ default:
+ return absl::InternalError("Unsupported tensor storage type");
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status Tensor::WriteData(CLCommandQueue *queue, const TensorFloat32 &src)
+{
+ RETURN_IF_ERROR(IsValid(src.shape));
+ return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue);
+}
+
+absl::Status Tensor::WriteData(CLCommandQueue *queue,
+ const InternalTensor<Linear, DataType::FLOAT32> &src)
+{
+ return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue);
+}
+
+absl::Status Tensor::WriteData(CLCommandQueue *queue,
+ const InternalTensor<HWC, DataType::FLOAT32> &src)
+{
+ return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue);
+}
+
+absl::Status Tensor::WriteData(CLCommandQueue *queue, const Tensor5DFloat32 &src)
+{
+ RETURN_IF_ERROR(IsValid(src.shape));
+ return WriteDataBHWDC(absl::MakeConstSpan(src.data), queue);
+}
+
+absl::Status Tensor::ReadDataBHWDC(absl::Span<float> out, CLCommandQueue *queue) const
+{
+ void *data_ptr = nullptr;
+ const int aligned_channels = GetAlignedChannels();
+ const int elements_count = shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;
+ const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
+
+ std::vector<float> data_f;
+ data_f.resize(elements_count);
+ data_ptr = data_f.data();
+ switch (descriptor_.storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ RETURN_IF_ERROR(queue->EnqueueReadBuffer(memory_, data_size, data_ptr));
+ break;
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::TEXTURE_3D:
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ RETURN_IF_ERROR(queue->EnqueueReadImage(memory_, GetFullTensorRegion(), data_ptr));
+ break;
+ default:
+ return absl::InternalError("Unsupported tensor storage type");
+ }
+
+ if (descriptor_.data_type == DataType::FLOAT32)
+ {
+ DataToBHWDC(absl::MakeConstSpan(data_f.data(), data_f.size()), shape_, descriptor_, out);
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status Tensor::ReadData(CLCommandQueue *queue, TensorFloat32 *dst) const
+{
+ RETURN_IF_ERROR(IsValid(dst->shape));
+ return ReadDataBHWDC(absl::MakeSpan(dst->data), queue);
+}
+
+absl::Status Tensor::ReadData(CLCommandQueue *queue, Tensor5DFloat32 *dst) const
+{
+ RETURN_IF_ERROR(IsValid(dst->shape));
+ return ReadDataBHWDC(absl::MakeSpan(dst->data), queue);
+}
+
+absl::Status Tensor::CreateFromDescriptor(const TensorDescriptor &desc, CLContext *context)
+{
+ shape_ = desc.shape;
+ descriptor_.data_type = desc.data_type;
+ descriptor_.storage_type = desc.storage_type;
+ descriptor_.layout = desc.layout;
+ memory_owner_ = true;
+ CLMemory memory;
+ uint8_t *data_ptr = desc.data.empty() ? nullptr : const_cast<unsigned char *>(desc.data.data());
+ RETURN_IF_ERROR(AllocateTensorMemory(*context, shape_, descriptor_, data_ptr, &memory));
+ memory_ = memory.Release();
+ if (desc.storage_type == TensorStorageType::IMAGE_BUFFER)
+ {
+ RETURN_IF_ERROR(CreateImageBufferFromBuffer(*context, memory_, desc.data_type,
+ shape_.b * shape_.w * shape_.h * shape_.d *
+ DivideRoundUp(shape_.c, 4),
+ &image_buffer_memory_));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status CreateTensor(const CLContext &context, const BHWC &shape,
+ const TensorDescriptor &descriptor, Tensor *result)
+{
+ const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
+ return CreateTensor(context, shape5D, descriptor, nullptr, result);
+}
+
+absl::Status CreateTensor(const CLContext &context, const BHWDC &shape,
+ const TensorDescriptor &descriptor, Tensor *result)
+{
+ return CreateTensor(context, shape, descriptor, nullptr, result);
+}
+
+absl::Status CreateSharedTensor(const CLContext &context, cl_mem memory, const BHWC &shape,
+ const TensorDescriptor &descriptor, Tensor *result)
+{
+ const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
+ return CreateTensorShared(context, shape5D, descriptor, memory, result);
+}
+
+absl::Status CreateSharedTensor(const CLContext &context, cl_mem memory, const BHWDC &shape,
+ const TensorDescriptor &descriptor, Tensor *result)
+{
+ return CreateTensorShared(context, shape, descriptor, memory, result);
+}
+
+absl::Status AllocateTensorMemory(const CLContext &context, const BHWC &shape,
+ const TensorDescriptor &descriptor, CLMemory *result)
+{
+ const BHWDC shape5D(shape.b, shape.h, shape.w, 1, shape.c);
+ return AllocateTensorMemory(context, shape5D, descriptor, nullptr, result);
+}
+
+absl::Status AllocateTensorMemory(const CLContext &context, const BHWDC &shape,
+ const TensorDescriptor &descriptor, CLMemory *result)
+{
+ return AllocateTensorMemory(context, shape, descriptor, nullptr, result);
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_TENSOR_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_TENSOR_H__
+
+#include <cstdint>
+#include <memory>
+
+#include "absl/types/span.h"
+#include "ClCommandQueue.h"
+#include "OpenclWrapper.h"
+#include "ClContext.h"
+#include "ClDevice.h"
+#include "ClMemory.h"
+#include "GpuObject.h"
+#include "TensorType.h"
+#include "Util.h"
+#include "DataType.h"
+#include "Shape.h"
+#include "Status.h"
+#include "InternalTensor.h"
+#include "Types.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class Tensor : public GPUObject
+{
+public:
+ Tensor() : memory_(nullptr), image_buffer_memory_(nullptr), memory_owner_(true) {}
+ Tensor(cl_mem memory, bool memory_owner, const BHWC &shape, const TensorDescriptor &descriptor);
+ Tensor(cl_mem memory, bool memory_owner, const BHWDC &shape, const TensorDescriptor &descriptor);
+ Tensor(cl_mem memory, bool memory_owner, cl_mem image_buffer_memory, const BHWC &shape,
+ const TensorDescriptor &descriptor);
+ Tensor(cl_mem memory, bool memory_owner, cl_mem image_buffer_memory, const BHWDC &shape,
+ const TensorDescriptor &descriptor);
+
+ // Move only
+ Tensor(Tensor &&tensor);
+ Tensor &operator=(Tensor &&tensor);
+ Tensor(const Tensor &) = delete;
+ Tensor &operator=(const Tensor &) = delete;
+
+ virtual ~Tensor() { Release(); }
+
+ absl::Status GetGPUResources(const GPUObjectDescriptor *obj_ptr,
+ GPUResourcesWithValue *resources) const override;
+
+ int Width() const { return shape_.w; }
+ int Height() const { return shape_.h; }
+ int Depth() const { return shape_.d; }
+ int Channels() const { return shape_.c; }
+ int Slices() const { return DivideRoundUp(shape_.c, 4); }
+ int Batch() const { return shape_.b; }
+ TensorDescriptor GetDescriptor() const { return descriptor_; }
+ DataType GetDataType() const { return descriptor_.data_type; }
+ TensorStorageType GetStorageType() const { return descriptor_.storage_type; }
+
+ // for profiling and memory statistics
+ uint64_t GetMemorySizeInBytes() const;
+
+ cl_mem GetMemoryPtr() const;
+
+  // For IMAGE_BUFFER tensors this returns the buffer memory ptr instead of the
+  // image memory ptr.
+ cl_mem GetMemoryPtrForWriting() const;
+
+ absl::Status WriteData(CLCommandQueue *queue, const TensorFloat32 &src);
+ absl::Status WriteData(CLCommandQueue *queue,
+ const InternalTensor<Linear, DataType::FLOAT32> &src);
+ absl::Status WriteData(CLCommandQueue *queue, const InternalTensor<HWC, DataType::FLOAT32> &src);
+
+ absl::Status WriteData(CLCommandQueue *queue, const Tensor5DFloat32 &src);
+ absl::Status ReadData(CLCommandQueue *queue, TensorFloat32 *dst) const;
+ absl::Status ReadData(CLCommandQueue *queue, Tensor5DFloat32 *dst) const;
+
+ absl::Status CreateFromDescriptor(const TensorDescriptor &desc, CLContext *context);
+
+private:
+ absl::Status IsValid(const BHWC &shape) const;
+ absl::Status IsValid(const BHWDC &shape) const;
+
+ int GetChannelsAlignment() const;
+ int GetAlignedChannels() const;
+
+ absl::Status WriteDataBHWDC(absl::Span<const float> in, CLCommandQueue *queue);
+ absl::Status ReadDataBHWDC(absl::Span<float> out, CLCommandQueue *queue) const;
+
+ int3 GetFullTensorRegion() const;
+ void Release();
+
+ cl_mem memory_;
+ cl_mem image_buffer_memory_; // for TensorStorageType::IMAGE_BUFFER only
+ bool memory_owner_;
+ BHWDC shape_;
+ TensorDescriptor descriptor_;
+};
+
+using TensorPtr = std::shared_ptr<Tensor>;
+
+absl::Status AllocateTensorMemory(const CLContext &context, const BHWC &shape,
+ const TensorDescriptor &descriptor, CLMemory *result);
+
+absl::Status AllocateTensorMemory(const CLContext &context, const BHWDC &shape,
+ const TensorDescriptor &descriptor, CLMemory *result);
+
+absl::Status CreateTensor(const CLContext &context, const BHWC &shape,
+ const TensorDescriptor &descriptor, Tensor *result);
+
+absl::Status CreateTensor(const CLContext &context, const BHWDC &shape,
+ const TensorDescriptor &descriptor, Tensor *result);
+
+absl::Status CreateSharedTensor(const CLContext &context, cl_mem memory, const BHWC &shape,
+ const TensorDescriptor &descriptor, Tensor *result);
+
+absl::Status CreateSharedTensor(const CLContext &context, cl_mem memory, const BHWDC &shape,
+ const TensorDescriptor &descriptor, Tensor *result);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_TENSOR_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TensorType.h"
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/substitute.h"
+#include "Shape.h"
+#include "DataType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+std::string GetWriteImageFromDataType(DataType data_type)
+{
+ if (data_type == DataType::FLOAT32)
+ {
+ return "write_imagef";
+ }
+ else if (data_type == DataType::FLOAT16)
+ {
+ return "write_imageh";
+ }
+ else
+ {
+ throw std::runtime_error("Not supported data type");
+ }
+}
+
+} // namespace
+
+std::string TextureAddressModeToString(TextureAddressMode address_mode)
+{
+ switch (address_mode)
+ {
+ case TextureAddressMode::DONT_CARE:
+ return "smp_none";
+ case TextureAddressMode::ZERO:
+ return "smp_zero";
+ }
+ return "";
+}
+
+std::string ToString(TensorStorageType type)
+{
+ switch (type)
+ {
+ case TensorStorageType::UNKNOWN:
+ return "TensorStorageType::UNKNOWN";
+ case TensorStorageType::BUFFER:
+ return "TensorStorageType::BUFFER";
+ case TensorStorageType::TEXTURE_ARRAY:
+ return "TensorStorageType::TEXTURE_ARRAY";
+ case TensorStorageType::TEXTURE_2D:
+ return "TensorStorageType::TEXTURE_2D";
+ case TensorStorageType::TEXTURE_3D:
+ return "TensorStorageType::TEXTURE_3D";
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return "TensorStorageType::SINGLE_TEXTURE_2D";
+ case TensorStorageType::IMAGE_BUFFER:
+ return "TensorStorageType::IMAGE_BUFFER";
+ }
+ return "";
+}
+
+TensorDescriptor::TensorDescriptor(TensorDescriptor &&desc)
+ : GPUObjectDescriptor(std::move(desc)), data_type(desc.data_type),
+ storage_type(desc.storage_type), layout(desc.layout), shape(desc.shape),
+ data(std::move(desc.data))
+{
+}
+TensorDescriptor &TensorDescriptor::operator=(TensorDescriptor &&desc)
+{
+ if (this != &desc)
+ {
+ std::swap(data_type, desc.data_type);
+ std::swap(storage_type, desc.storage_type);
+ std::swap(layout, desc.layout);
+ std::swap(shape, desc.shape);
+ data = std::move(desc.data);
+ GPUObjectDescriptor::operator=(std::move(desc));
+ }
+ return *this;
+}
+
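+// Declares the uniform ints and memory objects that generated kernel code for this
+// tensor references; Tensor::GetGPUResources() later binds the actual values to them.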
+GPUResources TensorDescriptor::GetGPUResources() const
+{
+ GPUResources resources;
+ if (HasAxis(Axis::WIDTH))
+ {
+ resources.ints.push_back("width");
+ resources.ints.push_back("width_div2");
+ resources.ints.push_back("width_div4");
+ resources.ints.push_back("width_batched");
+ resources.ints.push_back("width_batched_div2");
+ resources.ints.push_back("width_batched_div4");
+ }
+ if (HasAxis(Axis::HEIGHT))
+ {
+ resources.ints.push_back("height");
+ }
+ if (HasAxis(Axis::CHANNELS))
+ {
+ resources.ints.push_back("slices");
+ resources.ints.push_back("channels");
+ }
+ if (HasAxis(Axis::BATCH))
+ {
+ resources.ints.push_back("batch");
+ }
+ if (HasAxis(Axis::DEPTH))
+ {
+ resources.ints.push_back("depth");
+ }
+ if (storage_type == TensorStorageType::BUFFER)
+ {
+ GPUBufferDescriptor desc;
+ desc.data_type = data_type;
+ desc.access_type = access_type_;
+ desc.element_size = 4;
+ auto it1 = state_vars_.find("ElementsX2");
+ if (it1 != state_vars_.end() && it1->second == "true")
+ {
+ desc.element_size = 8;
+ }
+ auto it2 = state_vars_.find("ElementsX4");
+ if (it2 != state_vars_.end() && it2->second == "true")
+ {
+ desc.element_size = 16;
+ }
+ resources.buffers.push_back({"buffer", desc});
+ }
+ else if (storage_type == TensorStorageType::SINGLE_TEXTURE_2D ||
+ storage_type == TensorStorageType::TEXTURE_2D)
+ {
+ GPUImage2DDescriptor desc;
+ desc.data_type = data_type;
+ desc.access_type = access_type_;
+ resources.images2d.push_back({"image2d", desc});
+ }
+ else if (storage_type == TensorStorageType::TEXTURE_ARRAY)
+ {
+ GPUImage2DArrayDescriptor desc;
+ desc.data_type = data_type;
+ desc.access_type = access_type_;
+ resources.image2d_arrays.push_back({"image2d_array", desc});
+ }
+ else if (storage_type == TensorStorageType::TEXTURE_3D)
+ {
+ GPUImage3DDescriptor desc;
+ desc.data_type = data_type;
+ desc.access_type = access_type_;
+ resources.images3d.push_back({"image3d", desc});
+ }
+ else if (storage_type == TensorStorageType::IMAGE_BUFFER)
+ {
+ if (access_type_ == AccessType::READ)
+ {
+ GPUImageBufferDescriptor desc;
+ desc.data_type = data_type;
+ desc.access_type = access_type_;
+ resources.image_buffers.push_back({"image_buffer", desc});
+ }
+ else
+ {
+ GPUBufferDescriptor desc;
+ desc.data_type = data_type;
+ desc.access_type = access_type_;
+ desc.element_size = 4;
+ resources.buffers.push_back({"buffer", desc});
+ }
+ }
+ return resources;
+}
+
+absl::Status TensorDescriptor::PerformSelector(const std::string &selector,
+ const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const
+{
+ if (selector == "Width")
+ {
+ *result = GetWidth();
+ return absl::OkStatus();
+ }
+ else if (selector == "Height")
+ {
+ *result = "height";
+ return absl::OkStatus();
+ }
+ else if (selector == "Slices")
+ {
+ *result = "slices";
+ return absl::OkStatus();
+ }
+ else if (selector == "SliceStride")
+ {
+ *result = GetSliceStride();
+ return absl::OkStatus();
+ }
+ else if (selector == "Channels")
+ {
+ *result = "channels";
+ return absl::OkStatus();
+ }
+ else if (selector == "Batch")
+ {
+ if (HasAxis(Axis::BATCH))
+ {
+ *result = "batch";
+ }
+ else
+ {
+ *result = "1";
+ }
+ return absl::OkStatus();
+ }
+ else if (selector == "Depth")
+ {
+ *result = "depth";
+ return absl::OkStatus();
+ }
+ else if (selector == "SetBatchRef")
+ {
+ if (args.size() != 1)
+ {
+ return absl::InvalidArgumentError("Unsupported arguments in SetBatchRef selector");
+ }
+ state_vars_["batch_id"] = args[0];
+ *result = "";
+ return absl::OkStatus();
+ }
+ else if (selector == "Read")
+ {
+ return PerformReadSelector(args, template_args, result);
+ }
+ else if (selector == "Write")
+ {
+ return PerformWriteSelector(args, result);
+ }
+ else if (selector == "WriteLinear")
+ {
+ return PerformWriteLinearSelector(args, result);
+ }
+ else if (selector == "GetAddress")
+ {
+ return PerformGetAddressSelector(args, result);
+ }
+ else if (selector == "GetPtrWithSliceOffset")
+ {
+ return PerformGetPtrWithSliceOffsetSelector(args, result);
+ }
+ else if (selector == "GetWHOffset")
+ {
+ return PerformGetWHOffsetSelector(args, result);
+ }
+ else if (selector == "GetHandle")
+ {
+ return PerformGetHandleSelector(args, result);
+ }
+ else
+ {
+ return absl::NotFoundError(
+ absl::StrCat("TensorDescriptor don't have selector with name - ", selector));
+ }
+}
+
+absl::Status TensorDescriptor::PerformReadSelector(const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const
+{
+ DataType read_as_type = data_type;
+ if (!template_args.empty())
+ {
+ if (template_args.size() != 1)
+ {
+ return absl::NotFoundError("Unrecognized Read selector template arguments.");
+ }
+ else
+ {
+ RETURN_IF_ERROR(GetDataTypeFromTemplateArgs(template_args[0], &read_as_type));
+ }
+ }
+ if (args.size() == 1)
+ { // function overload for 1D linear types.
+ if (storage_type == TensorStorageType::BUFFER ||
+ storage_type == TensorStorageType::IMAGE_BUFFER)
+ {
+ *result = Read(read_as_type, args[0]);
+ return absl::OkStatus();
+ }
+ else
+ {
+ return absl::InvalidArgumentError(
+ "Read selector with single argument can be used only with linear "
+ "storage types(BUFFER or IMAGE_BUFFER)");
+ }
+ }
+ std::string xc;
+ std::string yc;
+ std::string zc;
+ std::string sc;
+ std::string bc;
+ bool parsed = ParseCoordsFromArgs(args, 0, &xc, &yc, &zc, &sc, &bc);
+ if (args.size() < 2 || !parsed)
+ {
+ return absl::NotFoundError("Unrecognized Read selector");
+ }
+
+ *result = Read(read_as_type, GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
+ return absl::OkStatus();
+}
+
+absl::Status TensorDescriptor::GetLinkingContextFromWriteSelector(
+ const std::vector<std::string> &args, std::string *value_name, std::string *x_coord,
+ std::string *y_coord, std::string *s_coord) const
+{
+ std::string xc;
+ std::string yc;
+ std::string zc;
+ std::string sc;
+ std::string bc;
+ bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
+ if (args.size() < 2 || !parsed)
+ {
+ return absl::NotFoundError("Unrecognized Write selector");
+ }
+ *value_name = args[0];
+ if (HasAxis(Axis::BATCH) && !IsBatchedWidth())
+ {
+ *x_coord = absl::StrCat("((", xc, ") * batch + (", bc, "))");
+ }
+ else
+ {
+ *x_coord = absl::StrCat("(", xc, ")");
+ }
+ *y_coord = absl::StrCat("(", yc, ")");
+ *s_coord = absl::StrCat("(", sc, ")");
+ return absl::OkStatus();
+}
+
+absl::Status TensorDescriptor::PerformWriteSelector(const std::vector<std::string> &args,
+ std::string *result) const
+{
+ std::string xc;
+ std::string yc;
+ std::string zc;
+ std::string sc;
+ std::string bc;
+ bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
+ if (args.size() < 2 || !parsed)
+ {
+ return absl::NotFoundError("Unrecognized Write selector");
+ }
+ *result = Write(args[0], GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
+ return absl::OkStatus();
+}
+
+absl::Status TensorDescriptor::PerformWriteLinearSelector(const std::vector<std::string> &args,
+ std::string *result) const
+{
+ if (storage_type != TensorStorageType::BUFFER && storage_type != TensorStorageType::IMAGE_BUFFER)
+ {
+ return absl::InvalidArgumentError("WriteLinear selector can be used only with linear "
+ "storages(BUFFER/IMAGE_BUFFER)");
+ }
+ if (args.size() != 2)
+ {
+ return absl::NotFoundError("Unrecognized WriteLinear selector");
+ }
+ *result = Write(args[0], "(" + args[1] + ")");
+ return absl::OkStatus();
+}
+
+std::string TensorDescriptor::Read(DataType read_as_type, const std::string &global_address) const
+{
+ const std::string read_as = read_as_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
+ std::string image_type;
+ if (storage_type == TensorStorageType::TEXTURE_2D ||
+ storage_type == TensorStorageType::SINGLE_TEXTURE_2D)
+ {
+ image_type = "image2d";
+ }
+ else if (storage_type == TensorStorageType::TEXTURE_3D)
+ {
+ image_type = "image3d";
+ }
+ else if (storage_type == TensorStorageType::TEXTURE_ARRAY)
+ {
+ image_type = "image2d_array";
+ }
+ switch (storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ if (read_as_type == data_type)
+ {
+ return absl::StrCat("buffer[", global_address, "]");
+ }
+ else
+ {
+ const std::string conversion =
+ read_as_type == DataType::FLOAT16 ? "convert_half4" : "convert_float4";
+ return absl::StrCat(conversion, "(buffer[", global_address, "])");
+ }
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::TEXTURE_3D:
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ case TensorStorageType::TEXTURE_ARRAY:
+ return absl::StrCat(read_as, "(", image_type,
+ ", " + TextureAddressModeToString(ModeFromState()) + ", ", global_address,
+ ")");
+ case TensorStorageType::IMAGE_BUFFER:
+ return absl::StrCat(read_as, "(image_buffer, ", global_address, ")");
+ case TensorStorageType::UNKNOWN:
+ return "";
+ }
+ return "";
+}
+
+std::string TensorDescriptor::Write(const std::string &var_name,
+ const std::string &global_address) const
+{
+ std::string image_type;
+ if (storage_type == TensorStorageType::TEXTURE_2D ||
+ storage_type == TensorStorageType::SINGLE_TEXTURE_2D)
+ {
+ image_type = "image2d";
+ }
+ else if (storage_type == TensorStorageType::TEXTURE_3D)
+ {
+ image_type = "image3d";
+ }
+ else if (storage_type == TensorStorageType::TEXTURE_ARRAY)
+ {
+ image_type = "image2d_array";
+ }
+ switch (storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ return absl::StrCat("buffer[", global_address, "] = ", var_name, ";\n");
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::TEXTURE_3D:
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ case TensorStorageType::TEXTURE_ARRAY:
+ return absl::StrCat(GetWriteImageFromDataType(data_type), "(", image_type, ", ",
+ global_address, ", ", var_name, ");\n");
+ case TensorStorageType::UNKNOWN:
+ return "";
+ }
+ return "";
+}
+
+absl::Status TensorDescriptor::PerformGetAddressSelector(const std::vector<std::string> &args,
+ std::string *result) const
+{
+ std::string xc;
+ std::string yc;
+ std::string zc;
+ std::string sc;
+ std::string bc;
+ bool parsed = ParseCoordsFromArgs(args, 1, &xc, &yc, &zc, &sc, &bc);
+ if (args.size() < 3 || !parsed)
+ {
+ return absl::NotFoundError("Unrecognized GetAddress selector");
+ }
+
+ *result = DeclareAddress(args[0], GetGlobalAddressNoDeclaration(xc, yc, zc, sc, bc));
+ return absl::OkStatus();
+}
+
+absl::Status
+TensorDescriptor::PerformGetPtrWithSliceOffsetSelector(const std::vector<std::string> &args,
+ std::string *result) const
+{
+ if (storage_type != TensorStorageType::BUFFER)
+ {
+ return absl::InvalidArgumentError(
+ "GetPtrWithSliceOffset selector can be used only with BUFFER");
+ }
+ if (args.size() != 1)
+ {
+ return absl::NotFoundError(
+ absl::StrCat("GetPtrWithSliceOffset require one argument(slice coordinate), but ",
+ args.size(), " was passed"));
+ }
+ *result = absl::StrCat("buffer + ", args[0], " * ", GetSliceStride());
+ return absl::OkStatus();
+}
+
+absl::Status TensorDescriptor::PerformGetWHOffsetSelector(const std::vector<std::string> &args,
+ std::string *result) const
+{
+ if (storage_type != TensorStorageType::BUFFER && storage_type != TensorStorageType::IMAGE_BUFFER)
+ {
+ return absl::InvalidArgumentError(
+ "GetWHOffset selector can be used only with BUFFER/IMAGE_BUFFER");
+ }
+ if (args.size() != 2)
+ {
+ return absl::NotFoundError(absl::StrCat(
+ "GetWHOffset require two arguments(X and Y coordinates), but ", args.size(), " was passed"));
+ }
+ if (HasAxis(Axis::BATCH) && !IsBatchedWidth())
+ {
+ auto it = state_vars_.find("batch_id");
+ std::string batch_id;
+ if (it == state_vars_.end())
+ {
+ return absl::NotFoundError(
+ "Not found batch_id. Should be setted up by SetBatchRef(). method");
+ }
+ else
+ {
+ batch_id = it->second;
+ }
+ *result = absl::StrCat("((", args[1], ") * ", GetWidth(), " + (", args[0], ")) * batch + (",
+ batch_id, ")");
+ }
+ else
+ {
+ *result = absl::StrCat("(", args[1], ") * ", GetWidth(), " + (", args[0], ")");
+ }
+ return absl::OkStatus();
+}
+
+absl::Status TensorDescriptor::PerformGetHandleSelector(const std::vector<std::string> &args,
+ std::string *result) const
+{
+ if (!args.empty())
+ {
+ return absl::NotFoundError(
+ absl::StrCat("GetHandle does not require arguments, but ", args.size(), " was passed"));
+ }
+ switch (storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ *result = "buffer";
+ return absl::OkStatus();
+ case TensorStorageType::IMAGE_BUFFER:
+ if (access_type_ == AccessType::READ)
+ {
+ *result = "image_buffer";
+ }
+ else
+ {
+ *result = "buffer";
+ }
+ return absl::OkStatus();
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ *result = "image2d";
+ return absl::OkStatus();
+ case TensorStorageType::TEXTURE_ARRAY:
+ *result = "image2d_array";
+ return absl::OkStatus();
+ case TensorStorageType::TEXTURE_3D:
+ *result = "image3d";
+ return absl::OkStatus();
+ case TensorStorageType::UNKNOWN:
+ return absl::UnavailableError("Unknown type");
+ }
+ return absl::UnavailableError("Unknown type");
+}
+
+std::string TensorDescriptor::DeclareAddress(const std::string &var_name,
+ const std::string &address) const
+{
+ return absl::StrCat(StorageTypeToAddressType(), " ", var_name, " = ", address, ";");
+}
+
+std::string TensorDescriptor::StorageTypeToAddressType() const
+{
+ switch (storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ return "int";
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return "int2";
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_3D:
+ return "int4";
+ case TensorStorageType::UNKNOWN:
+ return "";
+ }
+ return "";
+}
+
+std::string TensorDescriptor::GetGlobalAddressNoDeclarationWHS(const std::string &x,
+ const std::string &y,
+ const std::string &s) const
+{
+ switch (storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ {
+ return absl::Substitute("((($2) * height + ($1)) * $3 + ($0))", x, y, s, GetWidth());
+ }
+ case TensorStorageType::TEXTURE_2D:
+ return absl::Substitute("(int2)(($0), ($1) * slices + ($2))", x, y, s);
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return absl::StrCat("(int2)(", x, ", ", y, ")");
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_3D:
+ return absl::StrCat("(int4)(", x, ", ", y, ", ", s, ", 0)");
+ case TensorStorageType::UNKNOWN:
+ return "error";
+ }
+ return "error";
+}
+
+std::string TensorDescriptor::GetGlobalAddressNoDeclarationWHSB(const std::string &x,
+ const std::string &y,
+ const std::string &s,
+ const std::string &b) const
+{
+ switch (storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ return absl::Substitute("(((($3) * height + $2) * width + ($1)) * batch + ($0))", b, x, y, s);
+ case TensorStorageType::TEXTURE_2D:
+ return absl::Substitute("(int2)(($0) * batch + ($1), ($2) * slices + ($3))", x, b, y, s);
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return absl::Substitute("(int2)(($0) * batch + ($1), ($2))", x, b, y);
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_3D:
+ return absl::Substitute("(int4)(($0) * batch + ($1), ($2), ($3), 0)", x, b, y, s);
+ default:
+ throw std::runtime_error("Unknown storage type");
+ }
+}
+
+std::string TensorDescriptor::GetGlobalAddressNoDeclarationWHDS(const std::string &x,
+ const std::string &y,
+ const std::string &z,
+ const std::string &s) const
+{
+ switch (storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ {
+ return absl::Substitute("(((($3) * slices + ($2)) * height + ($1)) * $4 + ($0))", x, y, s, z,
+ GetWidth());
+ }
+ case TensorStorageType::TEXTURE_2D:
+ return absl::Substitute("(int2)(($0) * depth + ($1), ($2) * slices + ($3))", x, z, y, s);
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return absl::Substitute("(int2)(($0) * depth + ($1), ($2))", x, z, y);
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_3D:
+ return absl::Substitute("(int4)(($0), ($1), ($2) * slices + ($3), 0)", x, y, z, s);
+ case TensorStorageType::UNKNOWN:
+ return "error";
+ }
+ return "error";
+}
+
+std::string TensorDescriptor::GetGlobalAddressNoDeclarationWHDSB(const std::string &x,
+ const std::string &y,
+ const std::string &z,
+ const std::string &s,
+ const std::string &b) const
+{
+ switch (storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ return absl::Substitute("((((($4) * slices + ($3)) * height + $2) * width + ($1)) * batch + "
+ "($0))",
+ b, x, y, s, z);
+ case TensorStorageType::TEXTURE_2D:
+ return absl::Substitute("(int2)((($0) * batch + ($1)) * depth + ($2), ($3) * slices + ($4))",
+ x, b, z, y, s);
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return absl::Substitute("(int2)((($0) * batch + ($1)) * depth + ($2), ($3))", x, b, z, y);
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_3D:
+ return absl::Substitute("(int4)(($0) * batch + ($1), ($2), ($3) * slices + ($4), 0)", x, b, y,
+ z, s);
+ default:
+ throw std::runtime_error("Unknown storage type");
+ }
+}
+
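+// Dispatches to the WHS/WHSB/WHDS/WHDSB address builders depending on the logical
+// layout and on whether the batch is already folded into the width.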
+std::string TensorDescriptor::GetGlobalAddressNoDeclaration(const std::string &xc,
+ const std::string &yc,
+ const std::string &zc,
+ const std::string &sc,
+ const std::string &bc) const
+{
+ if (layout == Layout::HWC || (IsBatchedWidth() && layout == Layout::BHWC))
+ {
+ return GetGlobalAddressNoDeclarationWHS(xc, yc, sc);
+ }
+ else if (layout == Layout::BHWC)
+ {
+ return GetGlobalAddressNoDeclarationWHSB(xc, yc, sc, bc);
+ }
+ else if (layout == Layout::HWDC || (IsBatchedWidth() && layout == Layout::BHWDC))
+ {
+ return GetGlobalAddressNoDeclarationWHDS(xc, yc, zc, sc);
+ }
+ else if (layout == Layout::BHWDC)
+ {
+ return GetGlobalAddressNoDeclarationWHDSB(xc, yc, zc, sc, bc);
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported layout");
+ }
+}
+
+absl::Status TensorDescriptor::GetDataTypeFromTemplateArgs(const std::string &template_arg,
+ DataType *result) const
+{
+ std::string read_type = template_arg;
+ if (read_type == "FLT" || read_type == "ACCUM_FLT")
+ {
+ auto it = state_vars_.find(read_type);
+ if (it == state_vars_.end())
+ {
+ return absl::UnavailableError(
+ absl::StrCat("Read selector template argument ", read_type, " uninitialized."));
+ }
+ else
+ {
+ read_type = it->second;
+ }
+ }
+
+ if (read_type == "half")
+ {
+ *result = DataType::FLOAT16;
+ }
+ else if (read_type == "float")
+ {
+ *result = DataType::FLOAT32;
+ }
+ else
+ {
+ return absl::NotFoundError(
+ absl::StrCat("Unrecognized Read selector template argument - ", read_type));
+ }
+ return absl::OkStatus();
+}
+
+bool TensorDescriptor::HasAxis(Axis axis) const
+{
+ if (axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::CHANNELS)
+ {
+ return true;
+ }
+ if (axis == Axis::BATCH && (layout == Layout::BHWC || layout == Layout::BHWDC))
+ {
+ return true;
+ }
+ if (axis == Axis::DEPTH && (layout == Layout::HWDC || layout == Layout::BHWDC))
+ {
+ return true;
+ }
+ return false;
+}
+
+void TensorDescriptor::SetTextureAddressMode(TextureAddressMode mode)
+{
+ if (mode == TextureAddressMode::ZERO)
+ {
+ state_vars_["TextureMode"] = "ZERO";
+ }
+ else
+ {
+ state_vars_["TextureMode"] = "DONT_CARE";
+ }
+}
+
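+// Consumes coordinate arguments in W, H, D, S, B order for the axes this tensor has;
+// missing slice/batch coordinates fall back to the slice_id/batch_id state variables.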
+bool TensorDescriptor::ParseCoordsFromArgs(const std::vector<std::string> &args, int offset,
+ std::string *xc, std::string *yc, std::string *zc,
+ std::string *sc, std::string *bc) const
+{
+ if (HasAxis(Axis::WIDTH))
+ {
+ if ((size_t)offset >= args.size())
+ return false;
+ *xc = args[offset++];
+ }
+ if (HasAxis(Axis::HEIGHT))
+ {
+ if ((size_t)offset >= args.size())
+ return false;
+ *yc = args[offset++];
+ }
+ if (HasAxis(Axis::DEPTH))
+ {
+ if ((size_t)offset >= args.size())
+ return false;
+ *zc = args[offset++];
+ }
+ if (HasAxis(Axis::CHANNELS))
+ {
+ if ((size_t)offset >= args.size())
+ {
+ auto it = state_vars_.find("slice_id");
+ if (it == state_vars_.end())
+ {
+ return false;
+ }
+ else
+ {
+ *sc = it->second;
+ }
+ }
+ else
+ {
+ *sc = args[offset++];
+ }
+ }
+ if (HasAxis(Axis::BATCH) && !IsBatchedWidth())
+ {
+ if ((size_t)offset >= args.size())
+ {
+ auto it = state_vars_.find("batch_id");
+ if (it == state_vars_.end())
+ {
+ return false;
+ }
+ else
+ {
+ *bc = it->second;
+ }
+ }
+ else
+ {
+ *bc = args[offset++];
+ }
+ }
+ return true;
+}
+
+bool TensorDescriptor::IsBatchedWidth() const
+{
+ auto it = state_vars_.find("BatchedWidth");
+ return it != state_vars_.end() && it->second == "true";
+}
+
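+// Picks the width uniform name to emit, taking the ElementsX2/ElementsX4 and
+// BatchedWidth state flags into account.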
+std::string TensorDescriptor::GetWidth() const
+{
+ std::string div;
+ auto it1 = state_vars_.find("ElementsX2");
+ if (it1 != state_vars_.end() && it1->second == "true")
+ {
+ div = "_div2";
+ }
+ auto it2 = state_vars_.find("ElementsX4");
+ if (it2 != state_vars_.end() && it2->second == "true")
+ {
+ div = "_div4";
+ }
+ auto it = state_vars_.find("BatchedWidth");
+ if (it != state_vars_.end() && it->second == "true")
+ {
+ return "width_batched" + div;
+ }
+ else
+ {
+ return "width" + div;
+ }
+}
+
+std::string TensorDescriptor::GetSliceStride() const
+{
+ if (IsBatchedWidth())
+ {
+ return GetWidth() + " * height";
+ }
+ else
+ {
+ if (HasAxis(Axis::BATCH))
+ {
+ return GetWidth() + " * height * batch";
+ }
+ else
+ {
+ return GetWidth() + " * height";
+ }
+ }
+}
+
+TextureAddressMode TensorDescriptor::ModeFromState() const
+{
+ auto it = state_vars_.find("TextureMode");
+ if (it != state_vars_.end())
+ {
+ if (it->second == "ZERO")
+ {
+ return TextureAddressMode::ZERO;
+ }
+ else
+ {
+ return TextureAddressMode::DONT_CARE;
+ }
+ }
+ else
+ {
+ return TextureAddressMode::DONT_CARE;
+ }
+}
+
+void TensorDescriptor::UploadData(const InternalTensor<HWC, DataType::FLOAT32> &src)
+{
+ shape = BHWDC(1, src.shape.h, src.shape.w, 1, src.shape.c);
+ UploadData(absl::MakeConstSpan(src.data));
+}
+
+void TensorDescriptor::UploadData(const InternalTensor<Linear, DataType::FLOAT32> &src)
+{
+ shape = BHWDC(1, 1, 1, 1, src.shape.v);
+ UploadData(absl::MakeConstSpan(src.data));
+}
+
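+// Packs the source floats into the descriptor's internal byte storage in the GPU
+// physical layout, so they can serve as initial data when the GPU object is created.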
+void TensorDescriptor::UploadData(absl::Span<const float> src)
+{
+ int aligned_channels =
+ storage_type == TensorStorageType::SINGLE_TEXTURE_2D ? shape.c : AlignByN(shape.c, 4);
+ int elements_count = shape.b * shape.w * shape.h * shape.d * aligned_channels;
+ data.resize(elements_count * SizeOf(data_type));
+ if (data_type == DataType::FLOAT32)
+ {
+ float *gpu_data = reinterpret_cast<float *>(data.data());
+ DataFromBHWDC(src, shape, *this, absl::MakeSpan(gpu_data, elements_count));
+ }
+}
+
+bool TensorDescriptor::SupportsZeroClamp(const Axis &axis) const
+{
+ switch (storage_type)
+ {
+ case TensorStorageType::UNKNOWN:
+ return false;
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ return false;
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return axis == Axis::WIDTH || axis == Axis::HEIGHT;
+ case TensorStorageType::TEXTURE_3D:
+ return axis == Axis::WIDTH || axis == Axis::HEIGHT || axis == Axis::DEPTH;
+ }
+ return false;
+}
+
+bool TensorDescriptor::CanReadOutOfBorder(const Axis &) const
+{
+ switch (storage_type)
+ {
+ case TensorStorageType::UNKNOWN:
+ return false;
+ case TensorStorageType::BUFFER:
+ return false;
+ case TensorStorageType::IMAGE_BUFFER:
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::TEXTURE_3D:
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ case TensorStorageType::TEXTURE_ARRAY:
+ return true;
+ }
+ return false;
+}
+
+bool TensorDescriptor::IsLinear() const
+{
+ return storage_type == TensorStorageType::BUFFER ||
+ storage_type == TensorStorageType::IMAGE_BUFFER;
+}
+
+bool TensorDescriptor::ReturnsZeroForNegOneRead() const
+{
+ return storage_type == TensorStorageType::IMAGE_BUFFER;
+}
+
+namespace
+{
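+// Maps a (b, x, y, d, s, sub_c) coordinate to the flat element index of the physical
+// GPU layout used by each storage type (see the per-case layout comments below).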
+int GetLinearIndex(const TensorDescriptor &desc, const BHWDC &shape, int b, int x, int y, int d,
+ int s, int sub_c)
+{
+ const int slices = DivideRoundUp(shape.c, 4);
+ switch (desc.storage_type)
+ {
+ case TensorStorageType::BUFFER:
+ case TensorStorageType::IMAGE_BUFFER:
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_3D:
+ return ((((d * slices + s) * shape.h + y) * shape.w + x) * shape.b + b) * 4 +
+ sub_c; // DSHWBC4
+ case TensorStorageType::TEXTURE_2D:
+ return ((((y * slices + s) * shape.w + x) * shape.b + b) * shape.d + d) * 4 +
+ sub_c; // HSWBDC4
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return (((y * shape.w + x) * shape.b + b) * shape.d + d) * shape.c + sub_c; // HWBDC
+ default:
+ return -1;
+ }
+ return -1;
+}
+
+int GetChannelsAlignment(const TensorDescriptor &desc, const BHWDC &shape)
+{
+ return desc.storage_type == TensorStorageType::SINGLE_TEXTURE_2D ? shape.c : 4;
+}
+} // namespace
+
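+// Repacks CPU-side BHWDC float data into the GPU physical layout described by `desc`,
+// zero-padding the channels of the last partial slice of 4.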
+template <typename T>
+void DataFromBHWDC(absl::Span<const float> src, const BHWDC &shape, const TensorDescriptor &desc,
+ absl::Span<T> dst)
+{
+ const int channels_alignment = GetChannelsAlignment(desc, shape);
+ const int slices = DivideRoundUp(shape.c, 4);
+ for (int b = 0; b < shape.b; ++b)
+ {
+ for (int s = 0; s < slices; ++s)
+ {
+ for (int y = 0; y < shape.h; ++y)
+ {
+ for (int x = 0; x < shape.w; ++x)
+ {
+ for (int d = 0; d < shape.d; ++d)
+ {
+ for (int c = 0; c < channels_alignment; ++c)
+ {
+ float value;
+ if (s * 4 + c < shape.c)
+ {
+ const int cpu_index = shape.LinearIndex({b, y, x, d, s * 4 + c});
+ value = src[cpu_index];
+ }
+ else
+ {
+ value = 0.0f;
+ }
+ int gpu_index = GetLinearIndex(desc, shape, b, x, y, d, s, c);
+ dst[gpu_index] = value;
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+template void DataFromBHWDC<float>(absl::Span<const float> src, const BHWDC &shape,
+ const TensorDescriptor &desc, absl::Span<float> dst);
+
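+// Inverse of DataFromBHWDC: unpacks GPU-layout data back into CPU-side BHWDC order,
+// skipping the zero-padded channels of the last slice.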
+template <typename T>
+void DataToBHWDC(absl::Span<const T> src, const BHWDC &shape, const TensorDescriptor &desc,
+ absl::Span<float> dst)
+{
+ const int channels_alignment = GetChannelsAlignment(desc, shape);
+ const int slices = DivideRoundUp(shape.c, 4);
+ for (int b = 0; b < shape.b; ++b)
+ {
+ for (int s = 0; s < slices; ++s)
+ {
+ for (int y = 0; y < shape.h; ++y)
+ {
+ for (int x = 0; x < shape.w; ++x)
+ {
+ for (int d = 0; d < shape.d; ++d)
+ {
+ for (int c = 0; c < channels_alignment; ++c)
+ {
+ if (s * 4 + c >= shape.c)
+ {
+ continue;
+ }
+ int cpu_index = shape.LinearIndex({b, y, x, d, s * 4 + c});
+ int gpu_index = GetLinearIndex(desc, shape, b, x, y, d, s, c);
+ dst[cpu_index] = src[gpu_index];
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+template void DataToBHWDC<float>(absl::Span<const float> src, const BHWDC &shape,
+ const TensorDescriptor &desc, absl::Span<float> dst);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_TENSOR_TYPE_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_TENSOR_TYPE_H__
+
+#include <cstddef>
+#include <string>
+
+#include "absl/types/span.h"
+#include "GpuObject.h"
+#include "DataType.h"
+#include "InternalTensor.h"
+#include "Shape.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class TextureAddressMode
+{
+ DONT_CARE, // translated to CLK_ADDRESS_NONE
+ ZERO, // translated to CLK_ADDRESS_CLAMP
+};
+
+std::string TextureAddressModeToString(TextureAddressMode address_mode);
+
+enum class TensorStorageType
+{
+ UNKNOWN,
+ BUFFER,
+ IMAGE_BUFFER,
+ TEXTURE_2D,
+ TEXTURE_3D,
+ TEXTURE_ARRAY,
+ SINGLE_TEXTURE_2D
+};
+
+struct TensorDescriptor : public GPUObjectDescriptor
+{
+ TensorDescriptor() = default;
+ TensorDescriptor(DataType dt, TensorStorageType st, Layout l)
+ : data_type(dt), storage_type(st), layout(l)
+ {
+ }
+
+ TensorDescriptor(const TensorDescriptor &) = default;
+ TensorDescriptor &operator=(const TensorDescriptor &) = default;
+ TensorDescriptor(TensorDescriptor &&desc);
+ TensorDescriptor &operator=(TensorDescriptor &&desc);
+
+ bool operator==(const TensorDescriptor &d) const
+ {
+ return data_type == d.data_type && storage_type == d.storage_type && layout == d.layout;
+ }
+
+ bool operator!=(const TensorDescriptor &d) const { return !(*this == d); }
+
+ absl::Status PerformSelector(const std::string &selector, const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const override;
+
+ GPUResources GetGPUResources() const override;
+
+ absl::Status CreateGPUObject(CLContext *context, GPUObjectPtr *result) const override;
+ void Release() override { data.clear(); }
+
+ bool HasAxis(Axis axis) const;
+ void SetTextureAddressMode(TextureAddressMode mode);
+
+ absl::Status GetLinkingContextFromWriteSelector(const std::vector<std::string> &args,
+ std::string *value_name, std::string *x_coord,
+ std::string *y_coord, std::string *s_coord) const;
+
+ void UploadData(const InternalTensor<HWC, DataType::FLOAT32> &src);
+ void UploadData(const InternalTensor<Linear, DataType::FLOAT32> &src);
+
+ bool SupportsZeroClamp(const Axis &axis) const;
+ bool CanReadOutOfBorder(const Axis &axis) const;
+ bool IsLinear() const;
+
+  // Applicable only for storage types where IsLinear() is true.
+  // In that case the address is a single 1D component - addr (int).
+  // Returns true if reading this linear storage type at addr == -1 yields
+  // FLT4(0.0), and false otherwise.
+ bool ReturnsZeroForNegOneRead() const;
+
+ DataType data_type = DataType::UNKNOWN;
+ TensorStorageType storage_type = TensorStorageType::UNKNOWN;
+  // This field describes the logical layout; the actual (physical) GPU layout
+  // can be totally different.
+  Layout layout = Layout::UNKNOWN; // Supported layouts are HWC, BHWC, HWDC, BHWDC
+
+ // optional
+ BHWDC shape;
+ std::vector<uint8_t> data;
+
+private:
+ absl::Status PerformReadSelector(const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const;
+
+ absl::Status PerformGetAddressSelector(const std::vector<std::string> &args,
+ std::string *result) const;
+
+ absl::Status PerformGetPtrWithSliceOffsetSelector(const std::vector<std::string> &args,
+ std::string *result) const;
+
+ absl::Status PerformGetWHOffsetSelector(const std::vector<std::string> &args,
+ std::string *result) const;
+
+ absl::Status PerformGetHandleSelector(const std::vector<std::string> &args,
+ std::string *result) const;
+
+ std::string DeclareAddress(const std::string &var_name, const std::string &address) const;
+
+ std::string StorageTypeToAddressType() const;
+
+ absl::Status PerformWriteSelector(const std::vector<std::string> &args,
+ std::string *result) const;
+
+ absl::Status PerformWriteLinearSelector(const std::vector<std::string> &args,
+ std::string *result) const;
+
+ std::string Read(DataType read_as_type, const std::string &global_address) const;
+ std::string Write(const std::string &var_name, const std::string &global_address) const;
+
+ bool IsBatchedWidth() const;
+
+ std::string GetWidth() const;
+ std::string GetSliceStride() const;
+
+ TextureAddressMode ModeFromState() const;
+
+ absl::Status GetDataTypeFromTemplateArgs(const std::string &template_arg, DataType *result) const;
+
+ std::string GetGlobalAddressNoDeclarationWHS(const std::string &x, const std::string &y,
+ const std::string &s) const;
+ std::string GetGlobalAddressNoDeclarationWHSB(const std::string &x, const std::string &y,
+ const std::string &s, const std::string &b) const;
+ std::string GetGlobalAddressNoDeclarationWHDS(const std::string &x, const std::string &y,
+ const std::string &z, const std::string &s) const;
+ std::string GetGlobalAddressNoDeclarationWHDSB(const std::string &x, const std::string &y,
+ const std::string &z, const std::string &s,
+ const std::string &b) const;
+ std::string GetGlobalAddressNoDeclaration(const std::string &xc, const std::string &yc,
+ const std::string &zc, const std::string &sc,
+ const std::string &bc) const;
+
+ bool ParseCoordsFromArgs(const std::vector<std::string> &args, int offset, std::string *xc,
+ std::string *yc, std::string *zc, std::string *sc,
+ std::string *bc) const;
+
+ void UploadData(absl::Span<const float> src);
+};
+
+template <typename T>
+void DataFromBHWDC(absl::Span<const float> src, const BHWDC &shape, const TensorDescriptor &desc,
+ absl::Span<T> dst);
+
+template <typename T>
+void DataToBHWDC(absl::Span<const T> src, const BHWDC &shape, const TensorDescriptor &desc,
+ absl::Span<float> dst);
+
+std::string ToString(TensorStorageType type);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_TENSOR_TYPE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TensorTypeUtil.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+ObjectType ToObjectType(TensorStorageType type)
+{
+ switch (type)
+ {
+ case TensorStorageType::IMAGE_BUFFER:
+ case TensorStorageType::BUFFER:
+ return ObjectType::OPENCL_BUFFER;
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ case TensorStorageType::TEXTURE_2D:
+ case TensorStorageType::TEXTURE_ARRAY:
+ case TensorStorageType::TEXTURE_3D:
+ return ObjectType::OPENCL_TEXTURE;
+ default:
+ return ObjectType::UNKNOWN;
+ }
+}
+
+DataLayout ToDataLayout(TensorStorageType type)
+{
+ switch (type)
+ {
+ case TensorStorageType::BUFFER:
+ return DataLayout::DHWC4;
+ case TensorStorageType::IMAGE_BUFFER:
+ return DataLayout::DHWC4;
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ return DataLayout::BHWC;
+ case TensorStorageType::TEXTURE_2D:
+ return DataLayout::HDWC4;
+ case TensorStorageType::TEXTURE_ARRAY:
+ return DataLayout::DHWC4;
+ case TensorStorageType::TEXTURE_3D:
+ return DataLayout::DHWC4;
+ default:
+ return DataLayout::UNKNOWN;
+ }
+}
+
+TensorStorageType ToTensorStorageType(ObjectType object_type, DataLayout data_layout)
+{
+ switch (object_type)
+ {
+ case ObjectType::OPENCL_BUFFER:
+ return TensorStorageType::BUFFER;
+ case ObjectType::OPENCL_TEXTURE:
+ switch (data_layout)
+ {
+ case DataLayout::BHWC:
+ return TensorStorageType::SINGLE_TEXTURE_2D;
+ case DataLayout::DHWC4:
+ return TensorStorageType::TEXTURE_ARRAY;
+ case DataLayout::HDWC4:
+ return TensorStorageType::TEXTURE_2D;
+ default:
+ return TensorStorageType::UNKNOWN;
+ }
+ default:
+ return TensorStorageType::UNKNOWN;
+ }
+}
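+
+// For example (illustrative, derived from the mappings above):
+//   ToObjectType(TensorStorageType::TEXTURE_2D) == ObjectType::OPENCL_TEXTURE
+//   ToDataLayout(TensorStorageType::TEXTURE_2D) == DataLayout::HDWC4
+//   ToTensorStorageType(ObjectType::OPENCL_TEXTURE, DataLayout::HDWC4)
+//                                               == TensorStorageType::TEXTURE_2D
+// Note that IMAGE_BUFFER and TEXTURE_3D do not round-trip: they map back to
+// BUFFER and TEXTURE_ARRAY respectively, since each shares the same
+// ObjectType/DataLayout pair with that storage type.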
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_TENSOR_TYPE_UTIL_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_TENSOR_TYPE_UTIL_H__
+
+#include "Api.h"
+#include "TensorType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+ObjectType ToObjectType(TensorStorageType type);
+
+DataLayout ToDataLayout(TensorStorageType type);
+
+TensorStorageType ToTensorStorageType(ObjectType object_type, DataLayout data_layout);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_TENSOR_TYPE_UTIL_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Texture2d.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+// Creates new 4-channel 2D texture with cl_channel_type elements
+absl::Status CreateTexture2D(int width, int height, DataType type, void *data, CLContext *context,
+ Texture2D *result)
+{
+ cl_mem texture;
+ cl_channel_type channel_type = DataTypeToChannelType(type);
+ RETURN_IF_ERROR(
+ CreateRGBAImage2D(context->context(), width, height, channel_type, data, &texture));
+ *result = Texture2D(texture, width, height, channel_type);
+
+ return absl::OkStatus();
+}
+} // namespace
+
+Texture2DDescriptor::Texture2DDescriptor(Texture2DDescriptor &&desc)
+ : GPUObjectDescriptor(std::move(desc)), element_type(desc.element_type),
+ normalized(desc.normalized), normalized_type(desc.normalized_type), size(desc.size),
+ data(std::move(desc.data))
+{
+}
+
+Texture2DDescriptor &Texture2DDescriptor::operator=(Texture2DDescriptor &&desc)
+{
+ if (this != &desc)
+ {
+ std::swap(element_type, desc.element_type);
+ std::swap(normalized, desc.normalized);
+ std::swap(normalized_type, desc.normalized_type);
+ std::swap(size, desc.size);
+ data = std::move(desc.data);
+ GPUObjectDescriptor::operator=(std::move(desc));
+ }
+ return *this;
+}
+
+void Texture2DDescriptor::Release() { data.clear(); }
+
+GPUResources Texture2DDescriptor::GetGPUResources() const
+{
+ GPUResources resources;
+ GPUImage2DDescriptor desc;
+ desc.data_type = element_type;
+ desc.access_type = access_type_;
+ resources.images2d.push_back({"tex2d", desc});
+ return resources;
+}
+
+absl::Status Texture2DDescriptor::PerformSelector(const std::string &selector,
+ const std::vector<std::string> &args,
+ const std::vector<std::string> &,
+ std::string *result) const
+{
+ if (selector == "Read")
+ {
+ return PerformReadSelector(args, result);
+ }
+ else
+ {
+ return absl::NotFoundError(
+ absl::StrCat("Texture2DDescriptor doesn't have a selector with name - ", selector));
+ }
+}
+
+absl::Status Texture2DDescriptor::PerformReadSelector(const std::vector<std::string> &args,
+ std::string *result) const
+{
+ if (args.size() != 2)
+ {
+ return absl::NotFoundError(absl::StrCat("Texture2DDescriptor Read requires two arguments, but ",
+ args.size(), " were passed"));
+ }
+ std::string read;
+ switch (element_type)
+ {
+ case DataType::FLOAT32:
+ read = "read_imagef";
+ break;
+ case DataType::FLOAT16:
+ read = "read_imageh";
+ break;
+ case DataType::INT8:
+ case DataType::INT16:
+ case DataType::INT32:
+ if (normalized)
+ {
+ read = normalized_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
+ }
+ else
+ {
+ read = "read_imagei";
+ }
+ break;
+ case DataType::UINT8:
+ case DataType::UINT16:
+ case DataType::UINT32:
+ if (normalized)
+ {
+ read = normalized_type == DataType::FLOAT16 ? "read_imageh" : "read_imagef";
+ }
+ else
+ {
+ read = "read_imageui";
+ }
+ break;
+ default:
+ read = "unknown_type";
+ break;
+ }
+ *result = absl::StrCat(read, "(tex2d, smp_none, (int2)(", args[0], ", ", args[1], "))");
+ return absl::OkStatus();
+}
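+
+// For example (illustrative): with element_type == DataType::FLOAT32 and
+// args == {"X", "Y"}, PerformReadSelector produces the kernel snippet
+//   read_imagef(tex2d, smp_none, (int2)(X, Y))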
+
+absl::Status Texture2DDescriptor::CreateGPUObject(CLContext *context, GPUObjectPtr *result) const
+{
+ Texture2D gpu_texture;
+ RETURN_IF_ERROR(gpu_texture.CreateFromTexture2DDescriptor(*this, context));
+ *result = absl::make_unique<Texture2D>(std::move(gpu_texture));
+ return absl::OkStatus();
+}
+
+Texture2D::Texture2D(cl_mem texture, int width, int height, cl_channel_type type)
+ : texture_(texture), width_(width), height_(height), channel_type_(type)
+{
+}
+
+Texture2D::Texture2D(Texture2D &&texture)
+ : texture_(texture.texture_), width_(texture.width_), height_(texture.height_),
+ channel_type_(texture.channel_type_)
+{
+ texture.texture_ = nullptr;
+ texture.width_ = 0;
+ texture.height_ = 0;
+}
+
+Texture2D &Texture2D::operator=(Texture2D &&texture)
+{
+ if (this != &texture)
+ {
+ Release();
+ std::swap(channel_type_, texture.channel_type_);
+ std::swap(width_, texture.width_);
+ std::swap(height_, texture.height_);
+ std::swap(texture_, texture.texture_);
+ }
+ return *this;
+}
+
+void Texture2D::Release()
+{
+ if (texture_)
+ {
+ clReleaseMemObject(texture_);
+ texture_ = nullptr;
+ width_ = 0;
+ height_ = 0;
+ }
+}
+
+absl::Status Texture2D::GetGPUResources(const GPUObjectDescriptor *obj_ptr,
+ GPUResourcesWithValue *resources) const
+{
+ const auto *texture_desc = dynamic_cast<const Texture2DDescriptor *>(obj_ptr);
+ if (!texture_desc)
+ {
+ return absl::InvalidArgumentError("Expected Texture2DDescriptor on input.");
+ }
+
+ resources->images2d.push_back({"tex2d", texture_});
+ return absl::OkStatus();
+}
+
+absl::Status Texture2D::CreateFromTexture2DDescriptor(const Texture2DDescriptor &desc,
+ CLContext *context)
+{
+ width_ = desc.size.x;
+ height_ = desc.size.y;
+ channel_type_ = DataTypeToChannelType(desc.element_type, desc.normalized);
+ uint8_t *data_ptr = desc.data.empty() ? nullptr : const_cast<unsigned char *>(desc.data.data());
+ return CreateRGBAImage2D(context->context(), desc.size.x, desc.size.y, channel_type_, data_ptr,
+ &texture_);
+}
+
+// Creates new 4-channel 2D texture with f32 elements
+absl::Status CreateTexture2DRGBA32F(int width, int height, CLContext *context, Texture2D *result)
+{
+ return CreateTexture2D(width, height, DataType::FLOAT32, nullptr, context, result);
+}
+
+// Creates new 4-channel 2D texture with f16 elements
+absl::Status CreateTexture2DRGBA16F(int width, int height, CLContext *context, Texture2D *result)
+{
+ return CreateTexture2D(width, height, DataType::FLOAT16, nullptr, context, result);
+}
+
+absl::Status CreateTexture2DRGBA(DataType type, int width, int height, CLContext *context,
+ Texture2D *result)
+{
+ return CreateTexture2D(width, height, type, nullptr, context, result);
+}
+
+absl::Status CreateTexture2DRGBA(DataType type, int width, int height, void *data,
+ CLContext *context, Texture2D *result)
+{
+ return CreateTexture2D(width, height, type, data, context, result);
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_TEXTURE2D_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_TEXTURE2D_H__
+
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "ClCommandQueue.h"
+#include "ClContext.h"
+#include "GpuObject.h"
+#include "OpenclWrapper.h"
+#include "TensorType.h"
+#include "Util.h"
+#include "DataType.h"
+#include "Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+struct Texture2DDescriptor : public GPUObjectDescriptor
+{
+ DataType element_type;
+ bool normalized = false; // used with INT data types; if normalized, the kernel
+ // reads float data instead of integers.
+ DataType normalized_type; // can be FLOAT32 or FLOAT16; used only when
+ // normalized == true.
+
+ // optional
+ int2 size = int2(0, 0);
+ std::vector<uint8_t> data;
+
+ Texture2DDescriptor() = default;
+ Texture2DDescriptor(const Texture2DDescriptor &) = default;
+ Texture2DDescriptor &operator=(const Texture2DDescriptor &) = default;
+ Texture2DDescriptor(Texture2DDescriptor &&desc);
+ Texture2DDescriptor &operator=(Texture2DDescriptor &&desc);
+
+ absl::Status PerformSelector(const std::string &selector, const std::vector<std::string> &args,
+ const std::vector<std::string> &template_args,
+ std::string *result) const override;
+
+ GPUResources GetGPUResources() const override;
+ absl::Status PerformReadSelector(const std::vector<std::string> &args, std::string *result) const;
+
+ absl::Status CreateGPUObject(CLContext *context, GPUObjectPtr *result) const override;
+ void Release() override;
+};
+
+// Texture2D represents formatted GPU data storage.
+// Texture2D is movable but not copyable.
+class Texture2D : public GPUObject
+{
+public:
+ Texture2D() {} // allows Texture2D to be used as a class member
+ Texture2D(cl_mem texture, int width, int height, cl_channel_type type);
+
+ // Move only
+ Texture2D(Texture2D &&texture);
+ Texture2D &operator=(Texture2D &&texture);
+ Texture2D(const Texture2D &) = delete;
+ Texture2D &operator=(const Texture2D &) = delete;
+
+ virtual ~Texture2D() { Release(); }
+
+ cl_mem GetMemoryPtr() const { return texture_; }
+
+ // Writes data to the texture. Data must point to a region of exactly
+ // width * height * sizeof(pixel) bytes, where a pixel holds 4 channels of channel_type_.
+ template <typename T> absl::Status WriteData(CLCommandQueue *queue, const absl::Span<T> data);
+
+ // Reads data from Texture2D into CPU memory.
+ template <typename T> absl::Status ReadData(CLCommandQueue *queue, std::vector<T> *result) const;
+
+ absl::Status GetGPUResources(const GPUObjectDescriptor *obj_ptr,
+ GPUResourcesWithValue *resources) const override;
+
+ absl::Status CreateFromTexture2DDescriptor(const Texture2DDescriptor &desc, CLContext *context);
+
+private:
+ void Release();
+
+ cl_mem texture_ = nullptr;
+ int width_;
+ int height_;
+ cl_channel_type channel_type_;
+};
+
+using Texture2DPtr = std::shared_ptr<Texture2D>;
+
+// Creates new 4-channel 2D texture with f32 elements
+absl::Status CreateTexture2DRGBA32F(int width, int height, CLContext *context, Texture2D *result);
+
+// Creates new 4-channel 2D texture with f16 elements
+absl::Status CreateTexture2DRGBA16F(int width, int height, CLContext *context, Texture2D *result);
+
+absl::Status CreateTexture2DRGBA(DataType type, int width, int height, CLContext *context,
+ Texture2D *result);
+
+absl::Status CreateTexture2DRGBA(DataType type, int width, int height, void *data,
+ CLContext *context, Texture2D *result);
+
+template <typename T>
+absl::Status Texture2D::WriteData(CLCommandQueue *queue, const absl::Span<T> data)
+{
+ const int element_size = ChannelTypeToSizeInBytes(channel_type_);
+ if (sizeof(T) % element_size != 0)
+ {
+ return absl::InvalidArgumentError(
+ "Template type T does not have a suitable element type for the created texture.");
+ }
+ if (4 * width_ * height_ * element_size != data.size() * sizeof(T))
+ {
+ return absl::InvalidArgumentError(
+ "absl::Span<T> data size is different from texture allocated size.");
+ }
+
+ RETURN_IF_ERROR(queue->EnqueueWriteImage(texture_, int3(width_, height_, 1), data.data()));
+
+ return absl::OkStatus();
+}
+
+template <typename T>
+absl::Status Texture2D::ReadData(CLCommandQueue *queue, std::vector<T> *result) const
+{
+ const int element_size = ChannelTypeToSizeInBytes(channel_type_);
+ if (sizeof(T) != element_size)
+ {
+ return absl::InvalidArgumentError("Pixel format is different.");
+ }
+
+ const int elements_count = width_ * height_ * 4;
+ result->resize(elements_count);
+
+ return queue->EnqueueReadImage(texture_, int3(width_, height_, 1), result->data());
+}
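+
+// Minimal usage sketch (illustrative only; `queue` and `context` are assumed to
+// point to an initialized CLCommandQueue and CLContext):
+//
+//   Texture2D tex;
+//   RETURN_IF_ERROR(CreateTexture2DRGBA32F(64, 64, context, &tex));
+//   std::vector<float> pixels(64 * 64 * 4, 0.0f);        // 4 channels per pixel
+//   RETURN_IF_ERROR(tex.WriteData(queue, absl::MakeSpan(pixels)));
+//   std::vector<float> readback;
+//   RETURN_IF_ERROR(tex.ReadData(queue, &readback));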
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_TEXTURE2D_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_TYPES_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_TYPES_H__
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// TODO(akulik): make these types Google-style compliant.
+
+template <typename T> struct alignas(sizeof(T)) Vec4
+{
+ union {
+ struct
+ {
+ T x, y, z, w;
+ };
+ std::array<T, 4> data_;
+ };
+
+ Vec4() : Vec4(T(0.0f)) {}
+
+ template <typename S> Vec4(S x_, S y_, S z_, S w_) : x(x_), y(y_), z(z_), w(w_) {}
+ explicit Vec4(T v) : x(v), y(v), z(v), w(v) {}
+
+ template <typename S> explicit Vec4(S v) : x(v), y(v), z(v), w(v) {}
+
+ Vec4(const Vec4 &f) : x(f.x), y(f.y), z(f.z), w(f.w) {}
+
+ template <typename S> Vec4(const Vec4<S> &f) : x(f.x), y(f.y), z(f.z), w(f.w) {}
+
+ Vec4 &operator=(const Vec4 &other)
+ {
+ x = other.x;
+ y = other.y;
+ z = other.z;
+ w = other.w;
+ return *this;
+ }
+
+ static constexpr int size() { return 4; }
+
+ T &operator[](size_t n) { return data_[n]; }
+ T operator[](size_t n) const { return data_[n]; }
+
+ bool operator==(const Vec4 &value) const
+ {
+ return data_[0] == value[0] && data_[1] == value[1] && data_[2] == value[2] &&
+ data_[3] == value[3];
+ }
+ bool operator!=(const Vec4 &value) const { return !(this->operator==(value)); }
+};
+
+template <typename T> struct alignas(sizeof(T)) Vec3
+{
+ union {
+ struct
+ {
+ T x, y, z;
+ };
+ std::array<T, 3> data_;
+ };
+
+ Vec3() : Vec3(T(0.0f)) {}
+
+ template <typename S> constexpr Vec3(S x_, S y_, S z_) : x(x_), y(y_), z(z_) {}
+ explicit Vec3(T v) : x(v), y(v), z(v) {}
+
+ template <typename S> explicit Vec3(S v) : x(v), y(v), z(v) {}
+
+ Vec3(const Vec3 &f) : x(f.x), y(f.y), z(f.z) {}
+
+ template <typename S> Vec3(const Vec3<S> &f) : x(f.x), y(f.y), z(f.z) {}
+
+ Vec3 &operator=(const Vec3 &other)
+ {
+ x = other.x;
+ y = other.y;
+ z = other.z;
+ return *this;
+ }
+
+ static constexpr int size() { return 3; }
+
+ T &operator[](size_t n) { return data_[n]; }
+ T operator[](size_t n) const { return data_[n]; }
+ bool operator==(const Vec3 &value) const
+ {
+ return data_[0] == value[0] && data_[1] == value[1] && data_[2] == value[2];
+ }
+ bool operator!=(const Vec3 &value) const { return !(this->operator==(value)); }
+};
+
+template <typename T> struct alignas(sizeof(T)) Vec2
+{
+ union {
+ struct
+ {
+ T x, y;
+ };
+ std::array<T, 2> data_;
+ };
+
+ Vec2() : Vec2(T(0.0f)) {}
+
+ template <typename S> Vec2(S x_, S y_) : x(x_), y(y_) {}
+ explicit Vec2(T v) : x(v), y(v) {}
+
+ template <typename S> explicit Vec2(S v) : x(v), y(v) {}
+
+ Vec2(const Vec2 &f) : x(f.x), y(f.y) {}
+
+ template <typename S> Vec2(const Vec2<S> &f) : x(f.x), y(f.y) {}
+
+ Vec2 &operator=(const Vec2 &other)
+ {
+ x = other.x;
+ y = other.y;
+ return *this;
+ }
+
+ bool operator==(const Vec2 &value) const { return data_[0] == value[0] && data_[1] == value[1]; }
+
+ bool operator!=(const Vec2 &value) const { return !(this->operator==(value)); }
+
+ static constexpr int size() { return 2; }
+
+ T &operator[](size_t n) { return data_[n]; }
+ T operator[](size_t n) const { return data_[n]; }
+};
+
+using float2 = Vec2<float>;
+using byte2 = Vec2<int8_t>;
+using ubyte2 = Vec2<uint8_t>;
+using short2 = Vec2<int16_t>;
+using ushort2 = Vec2<uint16_t>;
+using int2 = Vec2<int32_t>;
+using uint2 = Vec2<uint32_t>;
+
+using float3 = Vec3<float>;
+using byte3 = Vec3<int8_t>;
+using ubyte3 = Vec3<uint8_t>;
+using short3 = Vec3<int16_t>;
+using ushort3 = Vec3<uint16_t>;
+using int3 = Vec3<int32_t>;
+using uint3 = Vec3<uint32_t>;
+
+using float4 = Vec4<float>;
+using byte4 = Vec4<int8_t>;
+using ubyte4 = Vec4<uint8_t>;
+using short4 = Vec4<int16_t>;
+using ushort4 = Vec4<uint16_t>;
+using int4 = Vec4<int32_t>;
+using uint4 = Vec4<uint32_t>;
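+
+// For example (illustrative): components can be accessed by name or by index.
+//   int2 size(4, 2);   // size.x == 4, size.y == 2, size[1] == 2
+//   float4 v(1.0f);    // v == float4(1.0f, 1.0f, 1.0f, 1.0f)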
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_TYPES_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Util.h"
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/substitute.h"
+#include "Status.h"
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+std::string CLErrorCodeToString(cl_int error_code)
+{
+ switch (error_code)
+ {
+ case CL_SUCCESS:
+ return "Success";
+ case CL_DEVICE_NOT_FOUND:
+ return "Device not found";
+ case CL_DEVICE_NOT_AVAILABLE:
+ return "Device not available";
+ case CL_COMPILER_NOT_AVAILABLE:
+ return "Compiler not available";
+ case CL_MEM_OBJECT_ALLOCATION_FAILURE:
+ return "Memory object allocation failure";
+ case CL_OUT_OF_RESOURCES:
+ return "Out of resources";
+ case CL_OUT_OF_HOST_MEMORY:
+ return "Out of host memory";
+ case CL_PROFILING_INFO_NOT_AVAILABLE:
+ return "Profiling information not available";
+ case CL_MEM_COPY_OVERLAP:
+ return "Memory copy overlap";
+ case CL_IMAGE_FORMAT_MISMATCH:
+ return "Image format mismatch";
+ case CL_IMAGE_FORMAT_NOT_SUPPORTED:
+ return "Image format not supported";
+ case CL_BUILD_PROGRAM_FAILURE:
+ return "Build program failure";
+ case CL_MAP_FAILURE:
+ return "Mapping failure";
+ case CL_MISALIGNED_SUB_BUFFER_OFFSET:
+ return "Misaligned sub-buffer offset";
+ case CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST:
+ return "Execution status error for events in wait list";
+ case CL_COMPILE_PROGRAM_FAILURE:
+ return "Compile program failure";
+ case CL_LINKER_NOT_AVAILABLE:
+ return "Linker not available";
+ case CL_LINK_PROGRAM_FAILURE:
+ return "Link program failure";
+ case CL_DEVICE_PARTITION_FAILED:
+ return "Device partition failed";
+ case CL_KERNEL_ARG_INFO_NOT_AVAILABLE:
+ return "Kernel argument information not available";
+
+ case CL_INVALID_VALUE:
+ return "Invalid value";
+ case CL_INVALID_DEVICE_TYPE:
+ return "Invalid device type";
+ case CL_INVALID_PLATFORM:
+ return "Invalid platform";
+ case CL_INVALID_DEVICE:
+ return "Invalid device";
+ case CL_INVALID_CONTEXT:
+ return "Invalid context";
+ case CL_INVALID_QUEUE_PROPERTIES:
+ return "Invalid queue properties";
+ case CL_INVALID_COMMAND_QUEUE:
+ return "Invalid command queue";
+ case CL_INVALID_HOST_PTR:
+ return "Invalid host pointer";
+ case CL_INVALID_MEM_OBJECT:
+ return "Invalid memory object";
+ case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR:
+ return "Invalid image format descriptor";
+ case CL_INVALID_IMAGE_SIZE:
+ return "Invalid image size";
+ case CL_INVALID_SAMPLER:
+ return "Invalid sampler";
+ case CL_INVALID_BINARY:
+ return "Invalid binary";
+ case CL_INVALID_BUILD_OPTIONS:
+ return "Invalid build options";
+ case CL_INVALID_PROGRAM:
+ return "Invalid program";
+ case CL_INVALID_PROGRAM_EXECUTABLE:
+ return "Invalid program executable";
+ case CL_INVALID_KERNEL_NAME:
+ return "Invalid kernel name";
+ case CL_INVALID_KERNEL_DEFINITION:
+ return "Invalid kernel definition";
+ case CL_INVALID_KERNEL:
+ return "Invalid kernel";
+ case CL_INVALID_ARG_INDEX:
+ return "Invalid argument index";
+ case CL_INVALID_ARG_VALUE:
+ return "Invalid argument value";
+ case CL_INVALID_ARG_SIZE:
+ return "Invalid argument size";
+ case CL_INVALID_KERNEL_ARGS:
+ return "Invalid kernel arguments";
+ case CL_INVALID_WORK_DIMENSION:
+ return "Invalid work dimension";
+ case CL_INVALID_WORK_GROUP_SIZE:
+ return "Invalid work group size";
+ case CL_INVALID_WORK_ITEM_SIZE:
+ return "Invalid work item size";
+ case CL_INVALID_GLOBAL_OFFSET:
+ return "Invalid global offset";
+ case CL_INVALID_EVENT_WAIT_LIST:
+ return "Invalid event wait list";
+ case CL_INVALID_EVENT:
+ return "Invalid event";
+ case CL_INVALID_OPERATION:
+ return "Invalid operation";
+ case CL_INVALID_GL_OBJECT:
+ return "Invalid GL object";
+ case CL_INVALID_BUFFER_SIZE:
+ return "Invalid buffer size";
+ case CL_INVALID_MIP_LEVEL:
+ return "Invalid mip-level";
+ case CL_INVALID_GLOBAL_WORK_SIZE:
+ return "Invalid global work size";
+ case CL_INVALID_PROPERTY:
+ return "Invalid property";
+ case CL_INVALID_IMAGE_DESCRIPTOR:
+ return "Invalid image descriptor";
+ case CL_INVALID_COMPILER_OPTIONS:
+ return "Invalid compiler options";
+ case CL_INVALID_LINKER_OPTIONS:
+ return "Invalid linker options";
+ case CL_INVALID_DEVICE_PARTITION_COUNT:
+ return "Invalid device partition count";
+ case CL_INVALID_PIPE_SIZE:
+ return "Invalid pipe size";
+ case CL_INVALID_DEVICE_QUEUE:
+ return "Invalid device queue";
+ case CL_INVALID_GL_SHAREGROUP_REFERENCE_KHR:
+ return "Invalid GL sharegroup reference KHR";
+
+ default:
+ return "Unknown OpenCL error";
+ }
+}
+
+int ChannelTypeToSizeInBytes(cl_channel_type type)
+{
+ switch (type)
+ {
+ case CL_FLOAT:
+ return 4;
+ default:
+ return 0;
+ }
+}
+
+absl::Status CreateCLBuffer(cl_context context, int size_in_bytes, bool read_only, void *data,
+ cl_mem *result)
+{
+ cl_mem_flags flags = read_only ? CL_MEM_READ_ONLY : CL_MEM_READ_WRITE;
+ if (data)
+ {
+ flags |= CL_MEM_COPY_HOST_PTR;
+ }
+ cl_int error_code;
+ *result = clCreateBuffer(context, flags, size_in_bytes, data, &error_code);
+ if (!*result)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to allocate device memory (clCreateBuffer): ",
+ CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+cl_channel_type DataTypeToChannelType(DataType type, bool normalized)
+{
+ switch (type)
+ {
+ case DataType::FLOAT32:
+ return CL_FLOAT;
+ case DataType::INT8:
+ return normalized ? CL_SNORM_INT8 : CL_SIGNED_INT8;
+ case DataType::UINT8:
+ return normalized ? CL_UNORM_INT8 : CL_UNSIGNED_INT8;
+ case DataType::INT16:
+ return normalized ? CL_SNORM_INT16 : CL_SIGNED_INT16;
+ case DataType::UINT16:
+ return normalized ? CL_UNORM_INT16 : CL_UNSIGNED_INT16;
+ case DataType::INT32:
+ return CL_SIGNED_INT32;
+ case DataType::UINT32:
+ return CL_UNSIGNED_INT32;
+ default:
+ return CL_FLOAT;
+ }
+}
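+
+// For example (illustrative): DataTypeToChannelType(DataType::UINT8) yields
+// CL_UNSIGNED_INT8, while DataTypeToChannelType(DataType::UINT8, /*normalized=*/true)
+// yields CL_UNORM_INT8, so the normalized integer texture is read as floats in kernels.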
+
+absl::Status CreateRGBAImage2D(cl_context context, int width, int height,
+ cl_channel_type channel_type, void *data, cl_mem *result)
+{
+ cl_image_desc desc;
+ desc.image_type = CL_MEM_OBJECT_IMAGE2D;
+ desc.image_width = width;
+ desc.image_height = height;
+ desc.image_depth = 0;
+ desc.image_row_pitch = 0;
+ desc.image_slice_pitch = 0;
+ desc.num_mip_levels = 0;
+ desc.num_samples = 0;
+ desc.buffer = nullptr;
+
+ cl_image_format format;
+ format.image_channel_order = CL_RGBA;
+ format.image_channel_data_type = channel_type;
+
+ cl_mem_flags flags = CL_MEM_READ_WRITE;
+ if (data)
+ {
+ flags |= CL_MEM_COPY_HOST_PTR;
+ }
+
+ cl_int error_code;
+ *result = CreateImage2DLegacy(context, flags, &format, &desc, data, &error_code);
+ if (error_code != CL_SUCCESS)
+ {
+ return absl::UnknownError(absl::StrCat("Failed to create 2D texture (clCreateImage): ",
+ CLErrorCodeToString(error_code)));
+ }
+ return absl::OkStatus();
+}
+
+std::string GetXStrideCorrected(const std::string &src_x, const std::string &batch_size,
+ const std::string &stride_x, const std::string &padding_x)
+{
+ // TODO(sorokin) check perf and optimize with floor() if needed
+ // int p0 = src_x / batch_size;
+ // int b0 = src_x % batch_size;
+ // return p0 * stride_x * batch_size + b0 + padding_x;
+ return absl::Substitute("((($0) / $1) * $2 * $1 + (($0) % $1) + $3)", src_x, batch_size, stride_x,
+ padding_x);
+}
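+
+// For example (illustrative): GetXStrideCorrected("x", "batch", "stride", "pad")
+// returns the string
+//   "(((x) / batch) * stride * batch + ((x) % batch) + pad)"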
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_UTIL_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_UTIL_H__
+
+#include <string>
+
+#include "absl/types/span.h"
+#include "OpenclWrapper.h"
+#include "DataType.h"
+#include "InternalTensor.h"
+#include "Status.h"
+#include "Types.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+// Calculates the corrected X coordinate when stride != 1 and batch != 1 for
+// layouts with B after W (for example HWBC4), where W and B are stored in one
+// axis of the GPU resources.
+std::string GetXStrideCorrected(const std::string &src_x, const std::string &batch_size,
+ const std::string &stride_x, const std::string &padding_x);
+
+ // @param n must be non-negative
+// @param divisor must be greater than zero
+template <typename T, typename N> T DivideRoundUp(T n, N divisor)
+{
+ const T div = static_cast<T>(divisor);
+ const T q = n / div;
+ return n % div == 0 ? q : q + 1;
+}
+
+template <> inline uint3 DivideRoundUp(uint3 n, uint3 divisor)
+{
+ return uint3(DivideRoundUp(n.x, divisor.x), DivideRoundUp(n.y, divisor.y),
+ DivideRoundUp(n.z, divisor.z));
+}
+
+// @param number or its components must be greater than zero
+// @param n must be greater than zero
+template <typename T, typename N> T AlignByN(T number, N n) { return DivideRoundUp(number, n) * n; }
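+
+// For example (illustrative): DivideRoundUp(10, 4) == 3 and AlignByN(10, 4) == 12,
+// i.e. 10 elements need 3 groups of 4 and align up to the next multiple of 4.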
+
+std::string CLErrorCodeToString(cl_int error_code);
+
+int ChannelTypeToSizeInBytes(cl_channel_type type);
+
+template <DataType S, typename T>
+void CopyLinearFLT4(const InternalTensor<Linear, S> &src, absl::Span<T> dst)
+{
+ const int dst_depth = dst.size();
+ for (int d = 0; d < dst_depth; ++d)
+ {
+ T val;
+ for (int i = 0; i < 4; ++i)
+ {
+ const int dst_ch = d * 4 + i;
+ val[i] = dst_ch >= src.shape.v ? 0.0f : src.data[dst_ch];
+ }
+ dst[d] = val;
+ }
+}
+
+absl::Status CreateCLBuffer(cl_context context, int size_in_bytes, bool read_only, void *data,
+ cl_mem *result);
+
+cl_channel_type DataTypeToChannelType(DataType type, bool normalized = false);
+absl::Status CreateRGBAImage2D(cl_context context, int width, int height,
+ cl_channel_type channel_type, void *data, cl_mem *result);
+
+template <DataType S, typename T>
+void RearrangeWeightsToOHWIOGroupI4O4(const InternalTensor<OHWI, S> &weights, int out_group_size,
+ absl::Span<T> dst)
+{
+ const int dst_slices = DivideRoundUp(weights.shape.o, 4);
+ const int src_slices = DivideRoundUp(weights.shape.i, 4);
+ const int dst_groups = DivideRoundUp(dst_slices, out_group_size);
+
+ int counter = 0;
+ for (int d = 0; d < dst_groups; ++d)
+ {
+ for (int y = 0; y < weights.shape.h; ++y)
+ {
+ for (int x = 0; x < weights.shape.w; ++x)
+ {
+ for (int s = 0; s < src_slices; ++s)
+ {
+ for (int d_group = 0; d_group < out_group_size; ++d_group)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ T filter;
+ for (int i = 0; i < 4; ++i)
+ {
+ const int s_ch = s * 4 + j;
+ const int d_ch = (d * out_group_size + d_group) * 4 + i;
+ if (s_ch < weights.shape.i && d_ch < weights.shape.o)
+ {
+ const int f_index = weights.shape.LinearIndex({d_ch, y, x, s_ch});
+ filter[i] = weights.data[f_index];
+ }
+ else
+ {
+ filter[i] = 0.0f;
+ }
+ }
+ dst[counter++] = filter;
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+template <DataType S, typename T>
+void RearrangeWeightsToODHWIOGroupI4O4(const InternalTensor<OHWDI, S> &weights, int out_group_size,
+ absl::Span<T> dst)
+{
+ const int dst_slices = DivideRoundUp(weights.shape.o, 4);
+ const int src_slices = DivideRoundUp(weights.shape.i, 4);
+ const int dst_groups = DivideRoundUp(dst_slices, out_group_size);
+
+ int counter = 0;
+ for (int d = 0; d < dst_groups; ++d)
+ {
+ for (int z = 0; z < weights.shape.d; ++z)
+ {
+ for (int y = 0; y < weights.shape.h; ++y)
+ {
+ for (int x = 0; x < weights.shape.w; ++x)
+ {
+ for (int s = 0; s < src_slices; ++s)
+ {
+ for (int d_group = 0; d_group < out_group_size; ++d_group)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ T filter;
+ for (int i = 0; i < 4; ++i)
+ {
+ const int s_ch = s * 4 + j;
+ const int d_ch = (d * out_group_size + d_group) * 4 + i;
+ if (s_ch < weights.shape.i && d_ch < weights.shape.o)
+ {
+ const int f_index = weights.shape.LinearIndex({d_ch, y, x, z, s_ch});
+ filter[i] = weights.data[f_index];
+ }
+ else
+ {
+ filter[i] = 0.0f;
+ }
+ }
+ dst[counter++] = filter;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+template <DataType S, typename T>
+void RearrangeWeightsToI4HWIOOGroupO4(const InternalTensor<OHWI, S> &weights, int out_group_size,
+ absl::Span<T> dst)
+{
+ const int dst_slices = DivideRoundUp(weights.shape.o, 4);
+ const int src_slices = DivideRoundUp(weights.shape.i, 4);
+ const int dst_groups = DivideRoundUp(dst_slices, out_group_size);
+
+ int counter = 0;
+ for (int j = 0; j < 4; ++j)
+ {
+ for (int y = 0; y < weights.shape.h; ++y)
+ {
+ for (int x = 0; x < weights.shape.w; ++x)
+ {
+ for (int s = 0; s < src_slices; ++s)
+ {
+ for (int d = 0; d < dst_groups; ++d)
+ {
+ for (int d_group = 0; d_group < out_group_size; ++d_group)
+ {
+ T filter;
+ for (int i = 0; i < 4; ++i)
+ {
+ const int s_ch = s * 4 + j;
+ const int d_ch = (d * out_group_size + d_group) * 4 + i;
+ if (s_ch < weights.shape.i && d_ch < weights.shape.o)
+ {
+ const int f_index = weights.shape.LinearIndex({d_ch, y, x, s_ch});
+ filter[i] = weights.data[f_index];
+ }
+ else
+ {
+ filter[i] = 0.0f;
+ }
+ }
+ dst[counter++] = filter;
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+template <DataType S, typename T>
+void RearrangeWeightsToI4DHWIOOGroupO4(const InternalTensor<OHWDI, S> &weights, int out_group_size,
+ absl::Span<T> dst)
+{
+ const int dst_slices = DivideRoundUp(weights.shape.o, 4);
+ const int src_slices = DivideRoundUp(weights.shape.i, 4);
+ const int dst_groups = DivideRoundUp(dst_slices, out_group_size);
+
+ int counter = 0;
+ for (int j = 0; j < 4; ++j)
+ {
+ for (int z = 0; z < weights.shape.d; ++z)
+ {
+ for (int y = 0; y < weights.shape.h; ++y)
+ {
+ for (int x = 0; x < weights.shape.w; ++x)
+ {
+ for (int s = 0; s < src_slices; ++s)
+ {
+ for (int d = 0; d < dst_groups; ++d)
+ {
+ for (int d_group = 0; d_group < out_group_size; ++d_group)
+ {
+ T filter;
+ for (int i = 0; i < 4; ++i)
+ {
+ const int s_ch = s * 4 + j;
+ const int d_ch = (d * out_group_size + d_group) * 4 + i;
+ if (s_ch < weights.shape.i && d_ch < weights.shape.o)
+ {
+ const int f_index = weights.shape.LinearIndex({d_ch, y, x, z, s_ch});
+ filter[i] = weights.data[f_index];
+ }
+ else
+ {
+ filter[i] = 0.0f;
+ }
+ }
+ dst[counter++] = filter;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_UTIL_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "open_cl/WinogradUtil.h"
+
+#include <cmath>
+#include <vector>
+
+#include "open_cl/DataType.h"
+#include "open_cl/Shape.h"
+#include "open_cl/Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace
+{
+// Matrices for Winograd transformations were computed with the method described
+// here https://openreview.net/pdf?id=H1ZaRZVKg
+std::vector<float> GetTransposedMatrixForWinograd(int width, int height)
+{
+ const float kDelta = std::sqrt(2.0f) / 2.0f;
+ std::vector<float> px(width);
+
+ px[0] = 0.0f;
+ const int points_count = (width - 1) / 2;
+ for (int i = 0; i < points_count; ++i)
+ {
+ px[i * 2 + 1] = kDelta * (i + 1.0f);
+ px[i * 2 + 2] = -kDelta * (i + 1.0f);
+ }
+ px[width - 1] = 1.0f;
+
+ std::vector<float> py(width, 1.0f);
+ py[width - 1] = 0.0f;
+
+ std::vector<float> result(height * width);
+ for (int y = 0; y < width; ++y)
+ {
+ for (int x = 0; x < height; ++x)
+ {
+ result[x * width + y] = std::pow(px[y], 1.0f * x) * std::pow(py[y], (height - 1.0f) - x);
+ }
+ }
+ return result;
+}
+
+std::vector<float> GetInversedMatrixForWinograd(int rank)
+{
+ auto matrix = GetTransposedMatrixForWinograd(rank, rank);
+ std::vector<float> inverted(rank * rank, 0.0f);
+ for (int i = 0; i < rank; ++i)
+ {
+ inverted[i * rank + i] = 1.0f;
+ }
+
+ for (int i = 1; i < rank - 1; ++i)
+ {
+ float inv_t = 1.0f / matrix[i * rank + i];
+ for (int x = i; x < rank; ++x)
+ {
+ matrix[i * rank + x] *= inv_t;
+ }
+ for (int x = 0; x < rank; ++x)
+ {
+ inverted[i * rank + x] *= inv_t;
+ }
+
+ for (int y = 0; y < rank; ++y)
+ {
+ if (y == i)
+ continue;
+ float t = matrix[y * rank + i];
+ for (int x = i; x < rank; ++x)
+ {
+ matrix[y * rank + x] -= t * matrix[i * rank + x];
+ }
+ for (int x = 0; x < rank; ++x)
+ {
+ inverted[y * rank + x] -= t * inverted[i * rank + x];
+ }
+ }
+ }
+
+ return inverted;
+}
+
+std::vector<float> Multiply(const std::vector<float> &a_mat, const std::vector<float> &b_mat, int m,
+ int n, int k)
+{
+ std::vector<float> result(m * k);
+ for (int y = 0; y < m; ++y)
+ {
+ for (int x = 0; x < k; ++x)
+ {
+ float sum = 0.0f;
+ for (int i = 0; i < n; ++i)
+ {
+ sum += a_mat[y * n + i] * b_mat[i * k + x];
+ }
+ result[y * k + x] = sum;
+ }
+ }
+ return result;
+}
+} // namespace
+
+std::vector<float> AtMatrixForWinograd4x4To6x6() { return GetTransposedMatrixForWinograd(6, 4); }
+
+std::vector<float> BtMatrixForWinograd4x4To6x6() { return GetInversedMatrixForWinograd(6); }
+
+void RearrangeWeightsToWinograd4x4To6x6Weights(
+ const gpu_cl::InternalTensor<gpu_cl::OHWI, gpu_cl::DataType::FLOAT32> &src_weights,
+ gpu_cl::InternalTensor<gpu_cl::OHWI, gpu_cl::DataType::FLOAT32> *dst_weights)
+{
+ gpu_cl::OHWI dst_shape;
+ dst_shape.o = src_weights.shape.o;
+ dst_shape.h = 6;
+ dst_shape.w = 6;
+ dst_shape.i = src_weights.shape.i;
+ dst_weights->shape = dst_shape;
+ dst_weights->data.resize(dst_shape.DimensionsProduct());
+
+ auto gt_mat = GetTransposedMatrixForWinograd(6, 3);
+ std::vector<float> g_mat(gt_mat.size());
+ for (int y = 0; y < 3; ++y)
+ {
+ for (int x = 0; x < 6; ++x)
+ {
+ g_mat[x * 3 + y] = gt_mat[y * 6 + x];
+ }
+ }
+
+ for (int d = 0; d < src_weights.shape.o; ++d)
+ {
+ for (int s = 0; s < src_weights.shape.i; ++s)
+ {
+ std::vector<float> in_vals(9);
+ for (int y = 0; y < 3; ++y)
+ {
+ for (int x = 0; x < 3; ++x)
+ {
+ const int f_index = src_weights.shape.LinearIndex({d, y, x, s});
+ in_vals[y * 3 + x] = src_weights.data[f_index];
+ }
+ }
+
+ auto temp_vals = Multiply(g_mat, in_vals, 6, 3, 3);
+ auto out_vals = Multiply(temp_vals, gt_mat, 6, 3, 6);
+ for (int y = 0; y < 6; ++y)
+ {
+ for (int x = 0; x < 6; ++x)
+ {
+ const int f_index = dst_shape.LinearIndex({d, y, x, s});
+ dst_weights->data[f_index] = out_vals[y * 6 + x];
+ }
+ }
+ }
+ }
+}
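+
+// In other words (illustrative summary): each 3x3 spatial filter W is transformed
+// into a 6x6 tile as G * W * G^T, where G is the 6x3 matrix built above from
+// GetTransposedMatrixForWinograd(6, 3); this is the standard Winograd F(4x4, 3x3)
+// weight transform.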
+
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_WINOGRAD_UTIL_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_WINOGRAD_UTIL_H__
+
+#include <vector>
+
+#include "open_cl/DataType.h"
+#include "open_cl/Shape.h"
+#include "open_cl/InternalTensor.h"
+
+namespace onert
+{
+namespace backend
+{
+
+// Matrices for Winograd transformations obtained with the method described here
+// https://openreview.net/pdf?id=H1ZaRZVKg
+
+// returns A transposed matrix(6 * 4) as array (24 values) for Winograd4x4To6x6
+std::vector<float> AtMatrixForWinograd4x4To6x6();
+
+// returns B transposed matrix(6 * 6) as array (36 values) for Winograd4x4To6x6
+std::vector<float> BtMatrixForWinograd4x4To6x6();
+
+void RearrangeWeightsToWinograd4x4To6x6Weights(
+ const gpu_cl::InternalTensor<gpu_cl::OHWI, gpu_cl::DataType::FLOAT32> &src_weights,
+ gpu_cl::InternalTensor<gpu_cl::OHWI, gpu_cl::DataType::FLOAT32> *dst_weights);
+
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_WINOGRAD_UTIL_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WorkgroupSelection.h"
+
+#include <math.h>
+
+#include <set>
+#include <vector>
+
+#include "Util.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+namespace
+{
+
+template <typename T>
+void AddCornerCases(const T &grid, int max_work_group_total_size, const T &max_work_group_sizes,
+ WorkGroupSizeAlignment x_alignment, WorkGroupSizeAlignment y_alignment,
+ WorkGroupSizeAlignment z_alignment, std::vector<T> *work_groups)
+{
+ for (int x = 1; x <= 4; ++x)
+ {
+ for (int y = 1; y <= 4; ++y)
+ {
+ for (int z = 1; z <= 4; ++z)
+ {
+ u_int32_t wg_x = DivideRoundUp(grid.x, x);
+ u_int32_t wg_y = DivideRoundUp(grid.y, y);
+ u_int32_t wg_z = DivideRoundUp(grid.z, z);
+ if (wg_x > static_cast<u_int32_t>(max_work_group_sizes.x) ||
+ wg_y > static_cast<u_int32_t>(max_work_group_sizes.y) ||
+ wg_z > static_cast<u_int32_t>(max_work_group_sizes.z) ||
+ wg_x * wg_y * wg_z > static_cast<u_int32_t>(max_work_group_total_size))
+ {
+ continue;
+ }
+ if (x_alignment == WorkGroupSizeAlignment::PRECISE && grid.x % wg_x != 0)
+ {
+ continue;
+ }
+ if (y_alignment == WorkGroupSizeAlignment::PRECISE && grid.y % wg_y != 0)
+ {
+ continue;
+ }
+ if (z_alignment == WorkGroupSizeAlignment::PRECISE && grid.z % wg_z != 0)
+ {
+ continue;
+ }
+ work_groups->push_back({wg_x, wg_y, wg_z});
+ }
+ }
+ }
+
+ // This always adds at least {1, 1, 1}.
+ for (u_int32_t x = 1; x <= 4; ++x)
+ {
+ for (u_int32_t y = 1; y <= 4; ++y)
+ {
+ for (u_int32_t z = 1; z <= 4; ++z)
+ {
+ if (x > static_cast<u_int32_t>(max_work_group_sizes.x) ||
+ y > static_cast<u_int32_t>(max_work_group_sizes.y) ||
+ z > static_cast<u_int32_t>(max_work_group_sizes.z) ||
+ x * y * z > static_cast<u_int32_t>(max_work_group_total_size))
+ {
+ continue;
+ }
+ if (x_alignment == WorkGroupSizeAlignment::PRECISE && grid.x % x != 0)
+ {
+ continue;
+ }
+ if (y_alignment == WorkGroupSizeAlignment::PRECISE && grid.y % y != 0)
+ {
+ continue;
+ }
+ if (z_alignment == WorkGroupSizeAlignment::PRECISE && grid.z % z != 0)
+ {
+ continue;
+ }
+ work_groups->push_back({x, y, z});
+ }
+ }
+ }
+}
+
+std::vector<int> GetDivisors(int number)
+{
+ const int max_divisor = static_cast<int>(sqrt(number));
+ std::vector<int> divisors;
+ // We don't know the number of divisors, so this is just a heuristic.
+ divisors.reserve(max_divisor / 3 + 1);
+ for (int i = 1; i <= max_divisor; ++i)
+ {
+ const int d = number / i;
+ if (i * d == number)
+ {
+ divisors.push_back(i);
+ if (d != i)
+ {
+ divisors.push_back(d);
+ }
+ }
+ }
+ return divisors;
+}
+
+std::vector<int> GetDivisorsForRange(int number, int range)
+{
+ const int last_number = number + range;
+ const int max_divisor = static_cast<int>(sqrt(last_number));
+ std::set<int> divisors;
+ for (int i = 1; i <= max_divisor; ++i)
+ {
+ const int remainder = number % i;
+ // iterate through the numbers in our range that are divisible by i
+ const int first_number = number + (i - remainder) % i;
+ if (first_number <= last_number)
+ {
+ divisors.insert(i);
+ }
+ for (int j = first_number; j <= last_number; j += i)
+ {
+ const int d = j / i;
+ if (d != i)
+ {
+ divisors.insert(d);
+ }
+ }
+ }
+ return std::vector<int>(divisors.begin(), divisors.end());
+}
+
+} // namespace
+
+std::vector<int> GetPossibleSizes(int number, WorkGroupSizeAlignment z_alignment)
+{
+ if (z_alignment == WorkGroupSizeAlignment::PRECISE)
+ {
+ // Use only potential sizes that cover the grid precisely:
+ // work group size * k (k is an integer) == grid_size
+ return GetDivisors(number);
+ }
+ else
+ {
+ // When choosing a work group size we can also use sizes where
+ // work group size * k (k is an integer) != grid_size (slightly bigger).
+ // This heuristic therefore looks for potential sizes that satisfy both
+ // work group size * k (k is an integer) <= grid_size + 5
+ // and work group size * k (k is an integer) >= grid_size.
+ return GetDivisorsForRange(number, 5);
+ }
+}
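+
+// For example (illustrative): GetPossibleSizes(12, WorkGroupSizeAlignment::PRECISE)
+// returns the divisors of 12 ({1, 2, 3, 4, 6, 12} in some order), while
+// GetPossibleSizes(12, WorkGroupSizeAlignment::NO_ALIGNMENT) also includes sizes
+// such as 7 (7 * 2 = 14 <= 12 + 5) that only cover the grid with some overshoot.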
+
+template <typename T>
+std::vector<T>
+GenerateWorkGroupSizes(const T &grid, int min_work_group_total_size, int max_work_group_total_size,
+ const T &max_work_group_sizes, WorkGroupSizeAlignment x_alignment,
+ WorkGroupSizeAlignment y_alignment, WorkGroupSizeAlignment z_alignment)
+{
+ std::vector<T> work_groups;
+ work_groups.reserve(64);
+
+ std::vector<int> sizes_x = GetPossibleSizes(grid.x, x_alignment);
+ std::vector<int> sizes_y = GetPossibleSizes(grid.y, y_alignment);
+ std::vector<int> sizes_z = GetPossibleSizes(grid.z, z_alignment);
+
+ for (auto x : sizes_x)
+ {
+ if (static_cast<int>(x) > static_cast<int>(max_work_group_sizes.x))
+ continue;
+ for (auto y : sizes_y)
+ {
+ if (static_cast<int>(y) > static_cast<int>(max_work_group_sizes.y))
+ continue;
+ for (auto z : sizes_z)
+ {
+ if (static_cast<int>(z) > static_cast<int>(max_work_group_sizes.z))
+ continue;
+ const int work_group_size = x * y * z;
+ if (work_group_size < min_work_group_total_size ||
+ work_group_size > max_work_group_total_size)
+ continue;
+ work_groups.push_back({x, y, z});
+ }
+ }
+ }
+
+ return work_groups;
+}
+
+// Explicit instantiations of GenerateWorkGroupSizes for int3 and uint3
+
+template std::vector<int3> GenerateWorkGroupSizes(const int3 &grid, int min_work_group_total_size,
+ int max_work_group_total_size,
+ const int3 &max_work_group_sizes,
+ WorkGroupSizeAlignment x_alignment,
+ WorkGroupSizeAlignment y_alignment,
+ WorkGroupSizeAlignment z_alignment);
+
+template std::vector<uint3> GenerateWorkGroupSizes(const uint3 &grid, int min_work_group_total_size,
+ int max_work_group_total_size,
+ const uint3 &max_work_group_sizes,
+ WorkGroupSizeAlignment x_alignment,
+ WorkGroupSizeAlignment y_alignment,
+ WorkGroupSizeAlignment z_alignment);
+
+template <typename T>
+void GenerateWorkGroupSizesAlignedToGrid(const T &grid, const T &max_work_group_size,
+ const int max_work_group_invocations,
+ std::vector<T> *work_groups)
+{
+ auto alignment = WorkGroupSizeAlignment::PRECISE;
+ *work_groups =
+ GenerateWorkGroupSizes<T>(grid, /*min_work_group_total_size = */ 32, max_work_group_invocations,
+ max_work_group_size, alignment, alignment, alignment);
+ // If the grid is too small, GenerateWorkGroupSizes above may not produce any work groups.
+ if (work_groups->empty())
+ {
+ AddCornerCases(grid, max_work_group_invocations, max_work_group_size, alignment, alignment,
+ alignment, work_groups);
+ }
+}
+
+// Explicit instantiations of GenerateWorkGroupSizesAlignedToGrid for int3 and uint3
+
+template void GenerateWorkGroupSizesAlignedToGrid(const int3 &grid, const int3 &max_work_group_size,
+ const int max_work_group_invocations,
+ std::vector<int3> *work_groups);
+
+template void GenerateWorkGroupSizesAlignedToGrid(const uint3 &grid,
+ const uint3 &max_work_group_size,
+ const int max_work_group_invocations,
+ std::vector<uint3> *work_groups);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_WORK_GROUP_SELECTION_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_WORK_GROUP_SELECTION_H__
+
+#include <vector>
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// PRECISE assumes that WorkGroupSize * k == GridSize;
+// NO_ALIGNMENT imposes no restrictions.
+// PRECISE is required when the kernel has no boundary checks;
+// if the kernel does check boundaries, either PRECISE or NO_ALIGNMENT can be used.
+enum class WorkGroupSizeAlignment
+{
+ PRECISE,
+ NO_ALIGNMENT
+};
+
+std::vector<int> GetPossibleSizes(int number, WorkGroupSizeAlignment z_alignment);
+
+// Explicit instantiations exist for int3 and uint3 in the .cc file
+
+template <typename T>
+std::vector<T>
+GenerateWorkGroupSizes(const T &grid, int min_work_group_total_size, int max_work_group_total_size,
+ const T &max_work_group_sizes, WorkGroupSizeAlignment x_alignment,
+ WorkGroupSizeAlignment y_alignment, WorkGroupSizeAlignment z_alignment);
+
+template <typename T>
+void GenerateWorkGroupSizesAlignedToGrid(const T &grid, const T &max_work_group_size,
+ const int max_work_group_invocations,
+ std::vector<T> *work_groups);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_WORK_GROUP_SELECTION_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Add.h"
+
+#include <cstring>
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "Util.h"
+#include "open_cl/Util.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+GPUOperation CreateAdd(const OperationDef &definition, const std::vector<int> &channels,
+ int dst_channels)
+{
+ GPUOperation add(definition);
+ int dst_depth = DivideRoundUp(dst_channels, 4);
+ int src0_depth = DivideRoundUp(channels[0], 4);
+ add.elementwise_ = true;
+ add.linkable_ = dst_depth == src0_depth;
+ if (src0_depth < dst_depth)
+ {
+ add.check_src_channels_size_ = true;
+ }
+ for (uint32_t i = 1; i < definition.src_tensors.size(); ++i)
+ {
+ const std::string tensor_name = absl::StrCat("src_data_", i);
+ auto src_desc = definition.src_tensors[i];
+ if (definition.IsBatchSupported())
+ {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ add.AddSrcTensor(tensor_name, src_desc);
+ add.code_ += "if (S_COORD < args." + tensor_name + ".Slices()) {\n";
+ add.code_ += " in_out_value += args." + tensor_name + ".Read(X_COORD, Y_COORD, S_COORD);\n";
+ add.code_ += "}\n";
+ }
+ return add;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_ADD_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_ADD_H__
+
+#include <string>
+#include <vector>
+
+#include "GpuOperation.h"
+#include "open_cl/Operations.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// The Add operation supports inputs with unequal channel counts (this makes it
+// possible to remove a Padding operation that pads the channels dimension with zeroes).
+GPUOperation CreateAdd(const OperationDef &definition, const std::vector<int> &channels,
+ int dst_channels);
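+
+// Minimal usage sketch (illustrative; `def` is assumed to be an OperationDef with
+// two source tensors):
+//   GPUOperation add = CreateAdd(def, /*channels=*/{8, 3}, /*dst_channels=*/8);
+// Here the second input has only 3 channels, so the generated code reads it only
+// for slices where S_COORD < args.src_data_1.Slices().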
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_ADD_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "open_cl/kernels/ConvBuffer1x1.h"
+
+#include <array>
+#include <string>
+#include <utility>
+
+#include "open_cl/ClDevice.h"
+#include "open_cl/kernels/Util.h"
+#include "open_cl/kernels/WorkGroupPicking.h"
+#include "open_cl/Precision.h"
+#include "open_cl/TensorType.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+// element_size must be 1, 2 or 4:
+// 1 - FLT4
+// 2 - FLT8
+// 4 - FLT16
+// This function generates the code for the arithmetic part of the convolution.
+std::string GetComputationPart(const int3 &block_size, int element_size,
+ CalculationsPrecision precision)
+{
+ const std::string hexes[16] = {"0", "1", "2", "3", "4", "5", "6", "7",
+ "8", "9", "a", "b", "c", "d", "e", "f"};
+ std::string c;
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string z_s = std::to_string(z);
+ c += " FLT16 W" + z_s + " = weights_cache[" + z_s + "];\n";
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ std::string s_index = std::to_string(y * block_size.x + x);
+ for (int e = 0; e < element_size; ++e)
+ {
+ std::string r_index = z_s + std::to_string(y) + std::to_string(x * element_size + e);
+ const std::string f0 = "W" + z_s + ".s0123";
+ const std::string f1 = "W" + z_s + ".s4567";
+ const std::string f2 = "W" + z_s + ".s89ab";
+ const std::string f3 = "W" + z_s + ".scdef";
+ switch (precision)
+ {
+ case CalculationsPrecision::F32:
+ case CalculationsPrecision::F16:
+ c += " r" + r_index + " += " + f0 + " * s" + s_index + ".s" + hexes[e * 4 + 0] +
+ ";\n";
+ c += " r" + r_index + " += " + f1 + " * s" + s_index + ".s" + hexes[e * 4 + 1] +
+ ";\n";
+ c += " r" + r_index + " += " + f2 + " * s" + s_index + ".s" + hexes[e * 4 + 2] +
+ ";\n";
+ c += " r" + r_index + " += " + f3 + " * s" + s_index + ".s" + hexes[e * 4 + 3] +
+ ";\n";
+ break;
+ case CalculationsPrecision::F32_F16:
+ c += " r" + r_index + " += convert_float4(" + f0 + " * s" + s_index + ".s" +
+ hexes[e * 4 + 0] + " + " + f1 + " * s" + s_index + ".s" + hexes[e * 4 + 1] +
+ " + " + f2 + " * s" + s_index + ".s" + hexes[e * 4 + 2] + " + " + f3 + " * s" +
+ s_index + ".s" + hexes[e * 4 + 3] + ");\n";
+ break;
+ }
+ }
+ }
+ }
+ }
+ return c;
+}
+
+ConvBuffer1x1::ConvParams GetBestParams(const DeviceInfo &device_info,
+ const OperationDef &definition, const BHWC &shape, int,
+ int dst_depth)
+{
+ ConvBuffer1x1::ConvParams conv_params;
+ conv_params.element_size = 4;
+ conv_params.block_size = int3(1, 1, 1);
+ if (!device_info.IsMali())
+ {
+ return conv_params;
+ }
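+  // Mali-specific tuning: FLT8 loads need an even batched width and non-F32
+  // precision; Midgard GPUs use a simpler heuristic than other Mali GPUs.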
+ bool can_use_flt8 =
+ (shape.w * shape.b) % 2 == 0 && definition.precision != CalculationsPrecision::F32;
+ bool is_midgard = device_info.IsMali() && device_info.mali_info.IsMidgard();
+ if (is_midgard)
+ {
+ if (can_use_flt8)
+ {
+ conv_params.element_size = 8;
+ }
+ if (definition.precision == CalculationsPrecision::F16 || !can_use_flt8)
+ {
+ conv_params.block_size.x = 2;
+ }
+ return conv_params;
+ }
+
+ int task_size = shape.w * shape.b * shape.h * dst_depth;
+ int block_size = GetRecommendedBlockSizeForConv(device_info, definition.precision, task_size);
+
+ if (!can_use_flt8 && block_size > 4)
+ {
+ block_size = 4;
+ }
+
+ if (can_use_flt8 && block_size >= 2)
+ {
+ conv_params.element_size = 8;
+ block_size /= 2;
+ }
+ if (block_size == 4)
+ {
+ conv_params.block_size.x = 2;
+ if (definition.precision == CalculationsPrecision::F32 && dst_depth < 32)
+ {
+ conv_params.block_size.y = 2;
+ }
+ else
+ {
+ conv_params.block_size.z = 2;
+ }
+ }
+ else if (block_size == 2)
+ {
+ if (dst_depth >= 32)
+ {
+ conv_params.block_size.z = 2;
+ }
+ else
+ {
+ conv_params.block_size.x = 2;
+ }
+ }
+
+ return conv_params;
+}
+
+ConvBuffer1x1::ConvParams GetBestParams(const DeviceInfo &device_info,
+ const OperationDef &definition, int, int)
+{
+ ConvBuffer1x1::ConvParams conv_params;
+ conv_params.element_size = 4;
+ conv_params.block_size = int3(1, 1, 1);
+ if (device_info.IsMali() && definition.precision == CalculationsPrecision::F16 &&
+ device_info.compute_units_count <= 4)
+ {
+ conv_params.block_size.x *= 2;
+ }
+ return conv_params;
+}
+
+} // namespace
+
+ConvBuffer1x1::ConvBuffer1x1(const OperationDef &definition, const ConvParams &conv_params)
+ : GPUOperation(definition), conv_params_(conv_params)
+{
+ code_ = GenerateConvBuffer1x1(definition_, conv_params_, &args_);
+ work_group_size_ = int3(2, 4, 1);
+}
+
+ConvBuffer1x1::ConvBuffer1x1(ConvBuffer1x1 &&operation)
+ : GPUOperation(std::move(operation)), conv_params_(std::move(operation.conv_params_))
+{
+}
+
+ConvBuffer1x1 &ConvBuffer1x1::operator=(ConvBuffer1x1 &&operation)
+{
+ if (this != &operation)
+ {
+ std::swap(conv_params_, operation.conv_params_);
+ GPUOperation::operator=(std::move(operation));
+ }
+ return *this;
+}
+
+std::string ConvBuffer1x1::GenerateConvBuffer1x1(const OperationDef &op_def,
+ const ConvBuffer1x1::ConvParams &conv_params,
+ Arguments *)
+{
+ auto src_desc = op_def.src_tensors[0];
+ if (op_def.IsBatchSupported())
+ {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ if (conv_params_.element_size == 8)
+ {
+ src_desc.SetStateVar("ElementsX2", "true");
+ }
+ else if (conv_params_.element_size == 16)
+ {
+ src_desc.SetStateVar("ElementsX4", "true");
+ }
+ AddSrcTensor("src_tensor", src_desc);
+ if (op_def.src_tensors.size() == 2)
+ {
+ // dynamic weights
+ BufferDescriptor desc;
+ desc.element_type = op_def.src_tensors[1].data_type;
+ desc.element_size = 16;
+ desc.memory_type = MemoryType::GLOBAL;
+ AddSrcBuffer("weights", desc);
+ }
+
+ auto dst_desc = op_def.dst_tensors[0];
+ if (op_def.IsBatchSupported())
+ {
+ dst_desc.SetStateVar("BatchedWidth", "true");
+ }
+ AddDstTensor("dst_tensor", dst_desc);
+
+ std::string c = GetCommonDefines(op_def.precision);
+ switch (op_def.precision)
+ {
+ case CalculationsPrecision::F32:
+ c += "#define FLT8 float8\n";
+ c += "#define FLT16 float16\n";
+ break;
+ case CalculationsPrecision::F32_F16:
+ case CalculationsPrecision::F16:
+ c += "#define FLT8 half8\n";
+ c += "#define FLT16 half16\n";
+ break;
+ }
+
+ const int3 block_size = conv_params.block_size;
+ const int element_size = conv_params.element_size / 4;
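+  // conv_params.element_size counts floats (4, 8 or 16); element_size above counts
+  // FLT4 vectors (1, 2 or 4), matching the FLT4/FLT8/FLT16 source reads generated below.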
+
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ c += " int X = get_global_id(0) * " + std::to_string(block_size.x * element_size) + ";\n";
+ c += " int X_SRC = get_global_id(0) * " + std::to_string(block_size.x) + ";\n";
+ c += " int Y = get_global_id(1) * " + std::to_string(block_size.y) + ";\n";
+ c += " int Z = get_global_id(2) * " + std::to_string(block_size.z) + ";\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+ "Z >= args.dst_tensor.Slices()) return;\n";
+ if (conv_params.different_weights_for_height)
+ {
+ c += " __global FLT16* weights_cache = args.weights.GetPtr() + (Z * "
+ "args.src_tensor.Height() + "
+ "Y * " +
+ std::to_string(block_size.z) +
+ ") * "
+ "args.src_tensor.Slices();\n";
+ }
+ else
+ {
+ c += " __global FLT16* weights_cache = args.weights.GetPtr() + Z * "
+ "args.src_tensor.Slices();\n";
+ }
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string z_s = std::to_string(z);
+ c += " ACCUM_FLT4 bias_val_" + z_s + " = TO_ACCUM_TYPE(args.biases.Read(Z + " + z_s + "));\n";
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ for (int x = 0; x < block_size.x * element_size; ++x)
+ {
+ c += " ACCUM_FLT4 r" + z_s + std::to_string(y) + std::to_string(x) + " = bias_val_" + z_s +
+ ";\n";
+ }
+ }
+ }
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ std::string x_s = std::to_string(x);
+ c += " int xc" + x_s + " = min(X_SRC + " + std::to_string(x) +
+ ", args.src_tensor.Width() - 1);\n";
+ }
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ std::string y_s = std::to_string(y);
+ c += " int yc" + y_s + " = min(Y + " + y_s + ", args.src_tensor.Height() - 1);\n";
+ }
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ std::string y_s = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ std::string x_s = std::to_string(x);
+ std::string i_s = std::to_string(y * block_size.x + x);
+ c += " int src_addr_" + i_s + " = (yc" + y_s + ") * args.src_tensor.Width() + (xc" + x_s +
+ ");\n";
+ }
+ }
+ c += " for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n";
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ std::string y_s = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ std::string x_s = std::to_string(x);
+ std::string i_s = std::to_string(y * block_size.x + x);
+ c += " FLT" + std::to_string(element_size * 4) + " s" + i_s +
+ " = args.src_tensor.Read(src_addr_" + i_s + ");\n";
+ }
+ }
+ c += GetComputationPart(block_size, element_size, op_def.precision);
+ for (int i = 0; i < block_size.x * block_size.y; ++i)
+ {
+ std::string i_s = std::to_string(i);
+ c += " src_addr_" + i_s + " += args.src_tensor.SliceStride();\n";
+ }
+ c += " weights_cache += " + std::to_string(block_size.z) + ";\n";
+ c += " }\n"; // SRC_SLICES
+
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string z_s = std::to_string(z);
+ if (z != 0)
+ {
+ c += " if (Z + " + z_s + " >= args.dst_tensor.Slices()) return;\n";
+ }
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string y_s = std::to_string(y);
+ for (int x = 0; x < block_size.x * element_size; ++x)
+ {
+ const std::string x_s = std::to_string(x);
+ c += " if (X + " + x_s + " < args.dst_tensor.Width() && Y + " + y_s +
+ " < args.dst_tensor.Height()) {\n";
+ c += " FLT4 res = TO_FLT4(r" + z_s + y_s + x_s + ");\n";
+ c += " args.dst_tensor.Write(res, X + " + x_s + ", Y + " + y_s + ", Z + " + z_s + ");\n";
+ c += " }\n";
+ }
+ }
+ }
+ c += "}\n";
+ return c;
+}
+
+int3 ConvBuffer1x1::GetGridSize() const
+{
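+  // Grid X covers the batched width in FLT4 columns, grouped by element_size/4 and
+  // block_size.x; Y and Z cover height and slices divided by their block sizes.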
+ const int dst_width_elements =
+ DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), (conv_params_.element_size / 4));
+ const int grid_x = DivideRoundUp(dst_width_elements, conv_params_.block_size.x);
+ const int grid_y = DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y);
+ const int grid_z = DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.z);
+ return int3(grid_x, grid_y, grid_z);
+}
+
+void ConvBuffer1x1::GetPossibleKernelWorkGroups(TuningType tuning_type,
+ const DeviceInfo &device_info,
+ const KernelInfo &kernel_info,
+ std::vector<int3> *work_groups) const
+{
+ GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, work_groups);
+}
+
+bool IsConvBuffer1x1Supported(const OperationDef &definition, const Convolution2DAttributes &attr)
+{
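+  // The buffer-based 1x1 path only applies to pointwise convolutions: BUFFER storage,
+  // a 1x1 kernel, unit strides and dilations, and no padding.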
+ auto src_storage_type = definition.src_tensors[0].storage_type;
+ return src_storage_type == TensorStorageType::BUFFER && attr.weights.shape.w == 1 &&
+ attr.weights.shape.h == 1 && attr.dilations.w == 1 && attr.dilations.h == 1 &&
+ attr.strides.w == 1 && attr.strides.h == 1 && attr.padding.prepended.w == 0 &&
+ attr.padding.prepended.h == 0 && attr.padding.appended.w == 0 &&
+ attr.padding.appended.h == 0;
+}
+
+bool IsConvBuffer1x1Supported(const OperationDef &definition, const BHWC &weights_shape,
+ const Convolution2DAttributes &attr)
+{
+ auto src_storage_type = definition.src_tensors[0].storage_type;
+ return src_storage_type == TensorStorageType::BUFFER && weights_shape.w == 1 &&
+ weights_shape.h == 1 && attr.dilations.w == 1 && attr.dilations.h == 1 &&
+ attr.strides.w == 1 && attr.strides.h == 1 && attr.padding.prepended.w == 0 &&
+ attr.padding.prepended.h == 0 && attr.padding.appended.w == 0 &&
+ attr.padding.appended.h == 0;
+}
+
+ConvBuffer1x1 CreateConvBuffer1x1(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr, const BHWC *shape)
+{
+ const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
+ const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
+ ConvBuffer1x1::ConvParams conv_params;
+ if (shape)
+ {
+ conv_params = GetBestParams(device_info, definition, *shape, src_depth, dst_depth);
+ }
+ else
+ {
+ conv_params = GetBestParams(device_info, definition, src_depth, dst_depth);
+ }
+ ConvBuffer1x1 result(definition, conv_params);
+ result.UploadData(attr.weights, attr.bias);
+ return result;
+}
+
+ConvBuffer1x1 CreateConvBuffer1x1(const DeviceInfo &device_info, const OperationDef &definition,
+ const FullyConnectedAttributes &attr, const BHWC *shape)
+{
+ const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
+ const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
+ ConvBuffer1x1::ConvParams conv_params;
+ if (shape)
+ {
+ conv_params = GetBestParams(device_info, definition, *shape, src_depth, dst_depth);
+ }
+ else
+ {
+ conv_params = GetBestParams(device_info, definition, src_depth, dst_depth);
+ }
+ conv_params.block_size.x *= conv_params.block_size.y;
+ conv_params.block_size.y = 1;
+ ConvBuffer1x1 result(definition, conv_params);
+ result.UploadData(attr.weights, attr.bias);
+ return result;
+}
+
+ConvBuffer1x1 CreateConvBuffer1x1Wino4x4To6x6(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC *shape)
+{
+ const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
+ const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
+ ConvBuffer1x1::ConvParams conv_params;
+ if (shape)
+ {
+ conv_params = GetBestParams(device_info, definition, *shape, src_depth, dst_depth);
+ }
+ else
+ {
+ conv_params = GetBestParams(device_info, definition, src_depth, dst_depth);
+ }
+ conv_params.block_size.x *= conv_params.block_size.y;
+ conv_params.block_size.y = 1;
+ conv_params.different_weights_for_height = true;
+ ConvBuffer1x1 result(definition, conv_params);
+ result.UploadDataForWinograd4x4To6x6(attr.weights);
+ return result;
+}
+
+ConvBuffer1x1 CreateConvBuffer1x1DynamicWeights(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC &weights_shape, const BHWC *dst_shape)
+{
+ const int dst_depth = DivideRoundUp(weights_shape.b, 4);
+ const int src_depth = DivideRoundUp(weights_shape.c, 4);
+ ConvBuffer1x1::ConvParams conv_params;
+ if (dst_shape)
+ {
+ conv_params = GetBestParams(device_info, definition, *dst_shape, src_depth, dst_depth);
+ }
+ else
+ {
+ conv_params = GetBestParams(device_info, definition, src_depth, dst_depth);
+ }
+ ConvBuffer1x1 result(definition, conv_params);
+ result.UploadBiases(attr.bias);
+ return result;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_BUFFER_1X1_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_BUFFER_1X1_H__
+
+#include "open_cl/Buffer.h"
+#include "open_cl/ClKernel.h"
+#include "open_cl/kernels/ConvCommon.h"
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/kernels/Util.h"
+#include "open_cl/LinearStorage.h"
+#include "open_cl/Precision.h"
+#include "open_cl/InternalTensor.h"
+#include "open_cl/Util.h"
+#include "open_cl/DataType.h"
+#include "open_cl/Operations.h"
+#include "open_cl/Shape.h"
+#include "open_cl/Status.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Types.h"
+#include "open_cl/WinogradUtil.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class ConvBuffer1x1 : public GPUOperation
+{
+public:
+ ConvBuffer1x1() = default;
+
+ // Move only
+ ConvBuffer1x1(ConvBuffer1x1 &&operation);
+ ConvBuffer1x1 &operator=(ConvBuffer1x1 &&operation);
+ ConvBuffer1x1(const ConvBuffer1x1 &) = delete;
+ ConvBuffer1x1 &operator=(const ConvBuffer1x1 &) = delete;
+
+ void GetPossibleKernelWorkGroups(TuningType tuning_type, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info,
+ std::vector<int3> *work_groups) const override;
+ int3 GetGridSize() const override;
+
+ ConvWeightsDescription GetConvWeightsDescription() const
+ {
+ ConvWeightsDescription desc;
+ desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4;
+ desc.output_group_size = conv_params_.block_size.z;
+ return desc;
+ }
+
+ struct ConvParams
+ {
+ int3 block_size = int3(1, 1, 1);
+ int element_size = 4; // can be 4, 8 or 16
+
+    // By default a 2D convolution uses the same weights for all spatial positions,
+    // but some cases (e.g. the Winograd 4x4-to-6x6 path) need separate weights per
+    // H coordinate; the kernel needs only minor modifications to support this.
+ bool different_weights_for_height = false;
+ };
+
+private:
+ ConvBuffer1x1(const OperationDef &definition, const ConvParams &conv_params);
+ friend ConvBuffer1x1 CreateConvBuffer1x1(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr, const BHWC *shape);
+ friend ConvBuffer1x1 CreateConvBuffer1x1(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const FullyConnectedAttributes &attr, const BHWC *shape);
+ friend ConvBuffer1x1 CreateConvBuffer1x1Wino4x4To6x6(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC *shape);
+ friend ConvBuffer1x1 CreateConvBuffer1x1DynamicWeights(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC &weights_shape,
+ const BHWC *dst_shape);
+
+ template <DataType T>
+ void UploadData(const InternalTensor<OHWI, T> &weights, const InternalTensor<Linear, T> &biases);
+ template <DataType T> void UploadDataForWinograd4x4To6x6(const InternalTensor<OHWI, T> &weights);
+
+ template <DataType T> void UploadWeights(const InternalTensor<OHWI, T> &weights);
+
+ template <DataType T> void UploadBiases(const InternalTensor<Linear, T> &biases);
+
+ std::string GenerateConvBuffer1x1(const OperationDef &op_def,
+ const ConvBuffer1x1::ConvParams &conv_params, Arguments *args);
+
+ ConvParams conv_params_;
+};
+
+template <DataType T>
+void ConvBuffer1x1::UploadData(const InternalTensor<OHWI, T> &weights,
+ const InternalTensor<Linear, T> &biases)
+{
+ UploadWeights(weights);
+ UploadBiases(biases);
+}
+
+template <DataType T>
+void ConvBuffer1x1::UploadDataForWinograd4x4To6x6(const InternalTensor<OHWI, T> &weights)
+{
+ InternalTensor<OHWI, T> wino_weights;
+ RearrangeWeightsToWinograd4x4To6x6Weights(weights, &wino_weights);
+ UploadWeights(wino_weights);
+ InternalTensor<Linear, DataType::FLOAT32> bias;
+ bias.shape = Linear(weights.shape.o);
+ bias.data.resize(weights.shape.o, 0.0f);
+ UploadBiases(bias);
+}
+
+template <DataType T> void ConvBuffer1x1::UploadWeights(const InternalTensor<OHWI, T> &weights)
+{
+ const int dst_depth = DivideRoundUp(weights.shape.o, 4);
+ const int src_depth = DivideRoundUp(weights.shape.i, 4);
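+  // Weights are repacked into the OHWI O-group-I4O4 layout (one FLT16 per 4 input
+  // channels x 4 output channels), with dst_depth aligned up to the z block size.
+  // Only the F32 repacking path is implemented here; the half4 branch is commented out.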
+
+ const bool f32_weights = definition_.precision == CalculationsPrecision::F32;
+ const int float4_size = sizeof(float4);
+  // TODO: use f32_weights ? sizeof(float4) : sizeof(half4) once half-precision
+  // weights are supported (see the commented-out half4 path below).
+
+ const int dst_depth_aligned = AlignByN(dst_depth, conv_params_.block_size.z);
+ const int elements_count = weights.shape.h * weights.shape.w * src_depth * dst_depth_aligned * 4;
+
+ BufferDescriptor desc;
+ desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.element_size = 16;
+ desc.memory_type = MemoryType::GLOBAL;
+ desc.size = float4_size * elements_count;
+ desc.data.resize(desc.size);
+
+ if (f32_weights)
+ {
+ float4 *ptr = reinterpret_cast<float4 *>(desc.data.data());
+ RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.z,
+ absl::MakeSpan(ptr, elements_count));
+ }
+ // else
+ // {
+ // half4 *ptr = reinterpret_cast<half4 *>(desc.data.data());
+ // RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.z,
+ // absl::MakeSpan(ptr, elements_count));
+ // }
+
+ args_.AddObject("weights", absl::make_unique<BufferDescriptor>(std::move(desc)));
+}
+
+template <DataType T> void ConvBuffer1x1::UploadBiases(const InternalTensor<Linear, T> &biases)
+{
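+  // Biases go into a linear BUFFER padded to a multiple of 4 * block_size.z channels,
+  // so the kernel can read a bias value for every output slice in the block.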
+ TensorLinearDescriptor desc;
+ desc.storage_type = LinearStorageType::BUFFER;
+ desc.element_type = definition_.GetDataType();
+ int depth = AlignByN(biases.shape.v, 4 * conv_params_.block_size.z) / 4;
+ desc.UploadLinearData(biases, depth);
+ args_.AddObject("biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+}
+
+bool IsConvBuffer1x1Supported(const OperationDef &definition, const Convolution2DAttributes &attr);
+
+bool IsConvBuffer1x1Supported(const OperationDef &definition, const BHWC &weights_shape,
+ const Convolution2DAttributes &attr);
+
+ConvBuffer1x1 CreateConvBuffer1x1(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr, const BHWC *shape = nullptr);
+
+ConvBuffer1x1 CreateConvBuffer1x1(const DeviceInfo &device_info, const OperationDef &definition,
+ const FullyConnectedAttributes &attr,
+ const BHWC *shape = nullptr);
+
+ConvBuffer1x1 CreateConvBuffer1x1DynamicWeights(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC &weights_shape,
+ const BHWC *dst_shape = nullptr);
+
+ConvBuffer1x1 CreateConvBuffer1x1Wino4x4To6x6(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC *shape = nullptr);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_BUFFER_1X1_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_COMMON_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_COMMON_H__
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class ConvWeightsLayout
+{
+ kUnknown,
+ kOHWIOGroupI4O4,
+};
+
+struct ConvWeightsDescription
+{
+ ConvWeightsLayout layout;
+ int output_group_size;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_COMMON_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "open_cl/kernels/ConvConstants.h"
+
+#include <string>
+#include <utility>
+
+#include "open_cl/kernels/Util.h"
+#include "open_cl/kernels/WorkGroupPicking.h"
+#include "open_cl/Precision.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+// Adreno can provide up to ~3-4KB of constant memory, but in some cases even
+// 3KB can have very bad performance.
+int GetAdrenoOptimalMaxConstantSize(int gpu_version)
+{
+ if (gpu_version < 600)
+ {
+ return 256 * 10; // 2.5KB
+ }
+ else
+ {
+ return 256 * 14; // 3.5KB
+ }
+}
+
+int GetOptimalMaxConstantSize(const DeviceInfo &info)
+{
+ if (!info.IsAdreno())
+ {
+    // This kernel is tuned for __constant memory, which is a big win on Adreno,
+    // so we do not expect it to be used on non-Adreno GPUs.
+ return 1024; // 1KB
+ }
+ else
+ {
+ return GetAdrenoOptimalMaxConstantSize(info.adreno_info.gpu_version);
+ }
+}
+
+std::string GenerateConvolutionConstantCode(const OperationDef &op_def, const OHWI &weights_shape,
+ bool stride_correction, GPUOperation *op)
+{
+ auto src_desc = op_def.src_tensors[0];
+ src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
+ if (op_def.IsBatchSupported())
+ {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op->AddSrcTensor("src_tensor", src_desc);
+
+ auto dst_desc = op_def.dst_tensors[0];
+ if (op_def.IsBatchSupported())
+ {
+ dst_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op->AddDstTensor("dst_tensor", dst_desc);
+
+ std::string c = GetCommonDefines(op_def.precision);
+
+ const int out_z = DivideRoundUp(weights_shape.o, 4);
+ const std::string kOutZ = std::to_string(out_z);
+ const int src_depth = DivideRoundUp(weights_shape.i, 4);
+
+ const auto src_tensor_type = op_def.src_tensors[0].storage_type;
+ const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
+ src_tensor_type == TensorStorageType::IMAGE_BUFFER;
+
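+  // BUFFER and IMAGE_BUFFER sources cannot rely on hardware clamping, so the
+  // generated code guards each read with explicit x_out / y_out checks.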
+ switch (op_def.precision)
+ {
+ case CalculationsPrecision::F32:
+ case CalculationsPrecision::F16:
+ c += "#define CONV4(R, SRC, F, i) \\\n";
+ c += " R += SRC.x * F[i + 0]; \\\n";
+ c += " R += SRC.y * F[i + 1]; \\\n";
+ c += " R += SRC.z * F[i + 2]; \\\n";
+ c += " R += SRC.w * F[i + 3]; \n";
+
+ c += "#define CONV3(R, SRC, F, i) \\\n";
+ c += " R += SRC.x * F[i + 0]; \\\n";
+ c += " R += SRC.y * F[i + 1]; \\\n";
+ c += " R += SRC.z * F[i + 2]; \n";
+
+ c += "#define CONV2(R, SRC, F, i) \\\n";
+ c += " R += SRC.x * F[i + 0]; \\\n";
+ c += " R += SRC.y * F[i + 1]; \n";
+
+ c += "#define CONV1(R, SRC, F, i) \\\n";
+ c += " R += SRC * F[i + 0]; \n";
+ break;
+ case CalculationsPrecision::F32_F16:
+ c += "#define CONV4(R, SRC, F, i) \\\n";
+ c += " R += convert_float4(SRC.x * F[i + 0] + SRC.y * F[i + 1]";
+ c += " + SRC.z * F[i + 2] + SRC.w * F[i + 3]);\n";
+
+ c += "#define CONV3(R, SRC, F, i) \\\n";
+ c += " R += convert_float4(SRC.x * F[i + 0] + SRC.y * F[i + 1]";
+ c += " + SRC.z * F[i + 2]);\n";
+
+ c += "#define CONV2(R, SRC, F, i) \\\n";
+ c += " R += convert_float4(SRC.x * F[i + 0] + SRC.y * F[i + 1]);\n";
+
+ c += "#define CONV1(R, SRC, F, i) \\\n";
+ c += " R += convert_float4(SRC * F[i + 0]);\n";
+ break;
+ }
+
+ const std::string postfixes[] = {".x", ".xy", ".xyz", ""};
+
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ c += " int X = get_global_id(0);\n";
+ c += " int Y = get_global_id(1);\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
+ "return;\n";
+ if (stride_correction)
+ {
+ c += " int start_x = " +
+ GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x", "args.padding_x") +
+ ";\n";
+ }
+ else
+ {
+ if (op_def.IsBatchSupported())
+ {
+ c += " int start_x = X * args.stride_x + args.padding_x * "
+ "args.src_tensor.Batch();\n";
+ }
+ else
+ {
+ c += " int start_x = X * args.stride_x + args.padding_x;\n";
+ }
+ }
+ c += " int start_y = Y * args.stride_y + args.padding_y;\n";
+ c += " ACCUM_FLT4 r[" + kOutZ + "];\n";
+ c += " for (int i = 0; i < " + kOutZ + "; ++i) {\n";
+ c += " r[i] = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
+ c += " }\n";
+ int filters_counter = 0;
+ for (int s = 0; s < src_depth; ++s)
+ {
+ const int ch_count = std::min(4, weights_shape.i - s * 4);
+ const std::string s_conv = "CONV" + std::to_string(ch_count);
+ const std::string s_count = ch_count == 1 ? "" : std::to_string(ch_count);
+ const std::string s_type = absl::StrCat("FLT", s_count);
+ const std::string s_postfix = postfixes[ch_count - 1];
+ const std::string dilation_x =
+ op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()" : "args.dilation_x";
+ for (int ky = 0; ky < weights_shape.h; ++ky)
+ {
+ std::string s_y = absl::StrCat("(start_y + ", ky, " * args.dilation_y)");
+ if (manual_clamp)
+ {
+ c += " {\n";
+ c += " bool y_out = " + s_y + " < 0 || " + s_y + " >= args.src_tensor.Height();\n";
+ }
+ for (int kx = 0; kx < weights_shape.w; ++kx)
+ {
+ c += " {\n";
+ std::string s_x = absl::StrCat("(start_x + ", kx, " * " + dilation_x + ")");
+ if (manual_clamp)
+ {
+ c += " bool x_out = " + s_x + "< 0 || " + s_x + ">= args.src_tensor.Width();\n";
+ c += " " + s_type + " src = x_out || y_out ?";
+ c += "(" + s_type + ")(0.0) : args.src_tensor.Read(" + s_x + ", " + s_y + ", " +
+ std::to_string(s) + ")" + s_postfix + ";\n";
+ }
+ else
+ {
+ c += " " + s_type + " src = args.src_tensor.Read(" + s_x + ", " + s_y + ", " +
+ std::to_string(s) + ")" + s_postfix + ";\n";
+ }
+ for (int d = 0; d < out_z; ++d)
+ {
+ c += " " + s_conv + "(r[" + std::to_string(d) + "], src, args.weigths.GetPtr(),";
+ c += " " + std::to_string(filters_counter) + ");\n";
+ filters_counter += ch_count;
+ }
+ c += " }\n";
+ }
+ if (manual_clamp)
+ {
+ c += " }\n";
+ }
+ }
+ }
+ for (int i = 0; i < out_z; ++i)
+ {
+ std::string s_i = std::to_string(i);
+ c += " {\n";
+ c += " FLT4 res = TO_FLT4(r[" + s_i + "]) + args.biases.Read(" + s_i + ");\n";
+ c += " args.dst_tensor.Write(res, X, Y, " + s_i + ");\n";
+ c += " }\n";
+ }
+ c += "}\n";
+ return c;
+}
+
+} // namespace
+
+bool IsConvConstantsSupported(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr)
+{
+ if (device_info.IsAMD() && definition.precision != CalculationsPrecision::F32 &&
+ definition.src_tensors[0].storage_type != TensorStorageType::BUFFER)
+ {
+    // BUG: some AMD GPUs crash without this check
+ return false;
+ }
+
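+  // The kernel is viable only if the whole filter fits in the optimal __constant
+  // budget and the accumulators (one FLT4 per output slice) stay within a small
+  // register count.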
+ const auto &w_shape = attr.weights.shape;
+ const int dst_channels = AlignByN(w_shape.o, 4);
+ const int filters_count = w_shape.i * dst_channels * w_shape.h * w_shape.w;
+ const int float_size = sizeof(float);
+  // TODO: distinguish F32 and F16, i.e.
+  // definition.precision == CalculationsPrecision::F32 ? sizeof(float) : sizeof(half);
+ const int filters_buffer_size = filters_count * float_size;
+ const int kConstantMaxSize = GetOptimalMaxConstantSize(device_info);
+ const int flt4_registers = DivideRoundUp(w_shape.o, 4);
+ return filters_buffer_size <= kConstantMaxSize && flt4_registers <= 8;
+}
+
+GPUOperation CreateConvConstants(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr)
+{
+ GPUOperation op(definition);
+ UploadWeightsForConvConstants(attr.weights, definition.precision, &op);
+ op.args_.AddInt("stride_x", attr.strides.w);
+ op.args_.AddInt("stride_y", attr.strides.h);
+ op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+ op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+ op.args_.AddInt("dilation_x", attr.dilations.w);
+ op.args_.AddInt("dilation_y", attr.dilations.h);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_ZIs1;
+
+ const bool stride_correction = definition.IsBatchSupported() && attr.strides.w != 1;
+ op.code_ =
+ GenerateConvolutionConstantCode(definition, attr.weights.shape, stride_correction, &op);
+ if (definition.precision == CalculationsPrecision::F16 && device_info.IsAdreno3xx())
+ {
+ op.compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE);
+ }
+ if (definition.precision != CalculationsPrecision::F32 && device_info.IsPowerVR())
+ {
+    // BUG: some PowerVR GPUs (e.g. GE8320) produce incorrect results without this option
+ op.compiler_options_.push_back(CompilerOptions::CL_OPT_DISABLE);
+ }
+
+ TensorLinearDescriptor desc;
+ desc.storage_type = LinearStorageType::BUFFER;
+ desc.element_type = definition.GetDataType();
+ desc.memory_type = MemoryType::CONSTANT;
+ desc.UploadLinearData(attr.bias);
+ op.args_.AddObject("biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+ return op;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_CONSTANTS_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_CONSTANTS_H__
+
+#include "open_cl/Buffer.h"
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/LinearStorage.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Util.h"
+#include "open_cl/DataType.h"
+#include "open_cl/Operations.h"
+#include "open_cl/Shape.h"
+#include "open_cl/Status.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Types.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+template <DataType S, typename T>
+void RearrangeWeightsForConvConstants(const InternalTensor<OHWI, S> &weights, absl::Span<T> dst)
+{
+ const int dst_depth = DivideRoundUp(weights.shape.o, 4);
+ const int src_depth = DivideRoundUp(weights.shape.i, 4);
+ const int kernel_x = weights.shape.w;
+ const int kernel_y = weights.shape.h;
+
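+  // For every (src slice, ky, kx, dst slice) store one FLT4 per valid input channel
+  // holding its 4 output-channel weights -- the layout the CONVn macros index
+  // linearly via filters_counter in the generated code.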
+ int counter = 0;
+ for (int s = 0; s < src_depth; ++s)
+ {
+ for (int y = 0; y < kernel_y; ++y)
+ {
+ for (int x = 0; x < kernel_x; ++x)
+ {
+ for (int d = 0; d < dst_depth; ++d)
+ {
+ const int channels_count = std::min(4, weights.shape.i - s * 4);
+ T filters[4];
+ for (int i = 0; i < 4; ++i)
+ {
+ for (int j = 0; j < channels_count; ++j)
+ {
+ const int s_ch = s * 4 + j;
+ const int d_ch = d * 4 + i;
+ if (s_ch < weights.shape.i && d_ch < weights.shape.o)
+ {
+ const int f_index = weights.shape.LinearIndex({d_ch, y, x, s_ch});
+ filters[i][j] = weights.data[f_index];
+ }
+ else
+ {
+ filters[i][j] = 0.0f;
+ }
+ }
+ }
+ T filters_new[4];
+ for (int i = 0; i < 4; ++i)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ filters_new[i][j] = filters[j][i];
+ }
+ }
+ for (int i = 0; i < channels_count; ++i)
+ {
+ dst[counter++] = filters_new[i];
+ }
+ }
+ }
+ }
+ }
+}
+
+template <DataType T>
+void UploadWeightsForConvConstants(const InternalTensor<OHWI, T> &weights,
+ CalculationsPrecision precision, GPUOperation *op)
+{
+ const int dst_depth = DivideRoundUp(weights.shape.o, 4);
+ const int kernel_x = weights.shape.w;
+ const int kernel_y = weights.shape.h;
+
+ const bool f32_weights = precision == CalculationsPrecision::F32;
+ const int float_size = f32_weights ? 4 : 2;
+ const int float_count = weights.shape.i * dst_depth * 4 * kernel_x * kernel_y;
+
+ BufferDescriptor desc;
+ desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.element_size = 4;
+ desc.memory_type = MemoryType::CONSTANT;
+ desc.size = float_size * float_count;
+ desc.data.resize(desc.size);
+
+ if (f32_weights)
+ {
+ float4 *ptr = reinterpret_cast<float4 *>(desc.data.data());
+ RearrangeWeightsForConvConstants(weights, absl::MakeSpan(ptr, float_count / 4));
+ }
+ // else
+ // {
+ // half4 *ptr = reinterpret_cast<half4 *>(desc.data.data());
+ // RearrangeWeightsForConvConstants(weights, absl::MakeSpan(ptr, float_count / 4));
+ // }
+
+ op->args_.AddObject("weigths", absl::make_unique<BufferDescriptor>(std::move(desc)));
+}
+
+bool IsConvConstantsSupported(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr);
+
+GPUOperation CreateConvConstants(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_CONSTANTS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "open_cl/kernels/ConvPowervr.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+
+#include "absl/strings/substitute.h"
+#include "open_cl/kernels/Util.h"
+#include "open_cl/kernels/WorkGroupPicking.h"
+#include "open_cl/Precision.h"
+#include "open_cl/TensorType.h"
+#include "open_cl/DataType.h"
+#include "open_cl/Shape.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+std::string GenerateUploadByThreads(const std::string &local_ptr_name,
+ const std::string &global_ptr_name,
+ const std::string &global_offset_name,
+ const std::string &lid_name, int total_work_items,
+ int elements_to_upload)
+{
+ std::string c;
+ std::string offset = global_offset_name.empty() ? "" : global_offset_name + " + ";
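+  // Each work item copies elements_to_upload / total_work_items values from global
+  // to local memory, plus one guarded copy for any remainder.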
+ const int groups = elements_to_upload / total_work_items;
+ const int reminder = elements_to_upload % total_work_items;
+ for (int i = 0; i < groups; ++i)
+ {
+ c += " " + local_ptr_name + "[" + lid_name + " + " + std::to_string(total_work_items * i) +
+ "] = " + global_ptr_name + "[" + offset + lid_name + " + " +
+ std::to_string(total_work_items * i) + "];\n";
+ }
+ if (reminder != 0)
+ {
+ c += " if (" + lid_name + " < " + std::to_string(reminder) + ") {\n";
+ c += " " + local_ptr_name + "[" + lid_name + " + " +
+ std::to_string(total_work_items * groups) + "] = " + global_ptr_name + "[" + offset +
+ lid_name + " + " + std::to_string(total_work_items * groups) + "];\n";
+ c += " }\n";
+ }
+ return c;
+}
+
+std::string GenerateAsyncUpload(const std::string &local_ptr_name,
+ const std::string &global_ptr_name,
+ const std::string &global_offset_name, int elements_to_upload)
+{
+ std::string c;
+ std::string offset = global_offset_name.empty() ? "" : " + " + global_offset_name;
+ c += " async_work_group_copy(" + local_ptr_name + ", " + global_ptr_name + offset + ", " +
+ std::to_string(elements_to_upload) + ", 0);\n";
+ return c;
+}
+
+std::string GenerateBlockCoords(const int4 &block_size, const int3 &work_group_launch_order,
+ bool linear_spatial, bool need_depth)
+{
+ std::string c;
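+  // When the work-group launch order is permuted, get_global_id can only be used for
+  // a dimension launched in its natural position; otherwise the group id is recovered
+  // through launch_remap.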
+ int3 launch_remap;
+ launch_remap[work_group_launch_order.x] = 0;
+ launch_remap[work_group_launch_order.y] = 1;
+ launch_remap[work_group_launch_order.z] = 2;
+ if (linear_spatial)
+ {
+ if (work_group_launch_order[0] == 0)
+ {
+ c += " int linear_spatial = get_global_id(0);\n";
+ }
+ else
+ {
+ c += " int linear_spatial = get_group_id(" + std::to_string(launch_remap[0]) +
+ ") * get_local_size(0) + get_local_id(0);\n";
+ }
+ if (need_depth)
+ {
+ c += " int DST_X = (linear_spatial % args.task_size_x) * " + std::to_string(block_size.x) +
+ ";\n";
+ c += " linear_spatial = linear_spatial / args.task_size_x;\n";
+ c += " int DST_Y = (linear_spatial % args.task_size_y) * " + std::to_string(block_size.y) +
+ ";\n";
+ c += " int DST_Z = (linear_spatial / args.task_size_y) * " + std::to_string(block_size.z) +
+ ";\n";
+ }
+ else
+ {
+ c += " int DST_Y = (linear_spatial / args.task_size_x) * " + std::to_string(block_size.y) +
+ ";\n";
+ c += " int DST_X = (linear_spatial % args.task_size_x) * " + std::to_string(block_size.x) +
+ ";\n";
+ }
+ if (work_group_launch_order[1] == 1)
+ {
+ c += " int DST_S = get_global_id(1) * " + std::to_string(block_size.w) + ";\n";
+ }
+ else
+ {
+ c += " int DST_S = (get_group_id(" + std::to_string(launch_remap[1]) +
+ ") * get_local_size(1) + get_local_id(1)) * " + std::to_string(block_size.w) + ";\n";
+ }
+ }
+ else
+ {
+ if (work_group_launch_order[0] == 0)
+ {
+ c += " int DST_X = get_global_id(0) * " + std::to_string(block_size.x) + ";\n";
+ }
+ else
+ {
+ c += " int DST_X = (get_group_id(" + std::to_string(launch_remap[0]) +
+ ") * get_local_size(0) + get_local_id(0)) * " + std::to_string(block_size.x) + ";\n";
+ }
+ std::string global_id_1;
+ if (work_group_launch_order[1] == 1)
+ {
+ global_id_1 = "get_global_id(1)";
+ }
+ else
+ {
+ global_id_1 = "(get_group_id(" + std::to_string(launch_remap[1]) +
+ ") * get_local_size(1) + get_local_id(1))";
+ }
+ if (need_depth)
+ {
+ c += " int linear_id_1 = " + global_id_1 + ";\n";
+ c +=
+ " int DST_Z = (linear_id_1 / args.task_size_y) * " + std::to_string(block_size.z) + ";\n";
+ c +=
+ " int DST_Y = (linear_id_1 % args.task_size_y) * " + std::to_string(block_size.y) + ";\n";
+ }
+ else
+ {
+ c += " int DST_Y = " + global_id_1 + " * " + std::to_string(block_size.y) + ";\n";
+ }
+ if (work_group_launch_order[2] == 2)
+ {
+ c += " int DST_S = get_global_id(2) * " + std::to_string(block_size.w) + ";\n";
+ }
+ else
+ {
+ c += " int DST_S = (get_group_id(" + std::to_string(launch_remap[2]) +
+ ") * get_local_size(2) + get_local_id(2)) * " + std::to_string(block_size.w) + ";\n";
+ }
+ }
+
+ return c;
+}
+} // namespace
+
+ConvPowerVR::ConvPowerVR(const OperationDef &definition, const Convolution2DAttributes &attr,
+ const DeviceInfo &device_info, const BHWC *dst_shape)
+ : GPUOperation(definition), stride_(attr.strides.w, attr.strides.h, 1, 1),
+ padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
+ kernel_size_(attr.weights.shape.w, attr.weights.shape.h, 1, 1),
+ dilation_(attr.dilations.w, attr.dilations.h, 1, 1),
+ conv_params_(GuessBestParams(device_info, definition, attr, dst_shape))
+{
+}
+
+ConvPowerVR::ConvPowerVR(const OperationDef &definition, const Convolution2DAttributes &attr,
+ const BHWC &weights_shape, const DeviceInfo &device_info,
+ const BHWC *dst_shape)
+ : GPUOperation(definition), stride_(attr.strides.w, attr.strides.h, 1, 1),
+ padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
+ kernel_size_(weights_shape.w, weights_shape.h, 1, 1),
+ dilation_(attr.dilations.w, attr.dilations.h, 1, 1),
+ conv_params_(GuessBestParams(device_info, definition, attr, weights_shape, dst_shape))
+{
+}
+
+ConvPowerVR::ConvPowerVR(const OperationDef &definition, const FullyConnectedAttributes &attr,
+ const DeviceInfo &device_info, const BHWC *dst_shape)
+ : GPUOperation(definition), stride_(1, 1, 1, 1), padding_(0, 0, 0, 0), kernel_size_(1, 1, 1, 1),
+ dilation_(1, 1, 1, 1), conv_params_(GuessBestParams(device_info, definition, attr, dst_shape))
+{
+}
+
+ConvPowerVR::ConvPowerVR(const OperationDef &definition)
+ : GPUOperation(definition), stride_(1, 1, 1, 1), padding_(0, 0, 0, 0), kernel_size_(1, 1, 1, 1),
+ dilation_(1, 1, 1, 1)
+{
+}
+
+ConvPowerVR::ConvPowerVR(ConvPowerVR &&operation)
+ : GPUOperation(std::move(operation)), stride_(operation.stride_), padding_(operation.padding_),
+ kernel_size_(operation.kernel_size_), dilation_(operation.dilation_),
+ conv_params_(operation.conv_params_)
+{
+}
+
+ConvPowerVR::ConvPowerVR(const OperationDef &definition, const Convolution3DAttributes &attr,
+ const DeviceInfo &device_info, const BHWDC *dst_shape)
+ : GPUOperation(definition), stride_(attr.strides.w, attr.strides.h, attr.strides.d, 1),
+ padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, -attr.padding.prepended.d, 0),
+ kernel_size_(attr.weights.shape.w, attr.weights.shape.h, attr.weights.shape.d, 1),
+ dilation_(attr.dilations.w, attr.dilations.h, attr.dilations.d, 1),
+ conv_params_(GuessBestParams(device_info, definition, attr, dst_shape))
+{
+}
+
+ConvPowerVR &ConvPowerVR::operator=(ConvPowerVR &&operation)
+{
+ if (this != &operation)
+ {
+ std::swap(stride_, operation.stride_);
+ std::swap(padding_, operation.padding_);
+ std::swap(kernel_size_, operation.kernel_size_);
+ std::swap(dilation_, operation.dilation_);
+ std::swap(conv_params_, operation.conv_params_);
+ GPUOperation::operator=(std::move(operation));
+ }
+ return *this;
+}
+
+void ConvPowerVR::GenerateCode(const DeviceInfo &device_info)
+{
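+  // With linear_spatial the spatial task is flattened into grid dimension 0, so only
+  // two grid dimensions are dispatched.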
+ if (conv_params_.linear_spatial)
+ {
+ grid_dimension_ = 2;
+ }
+ const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1;
+ code_ = GenerateConv(device_info, definition_, stride_correction, conv_params_);
+ if (definition_.precision == CalculationsPrecision::F16 && device_info.IsPowerVR())
+ {
+ compiler_options_.push_back(CompilerOptions::POWERVR_FP16);
+ }
+ if (conv_params_.IsPrivateMemBroadcast() && device_info.IsCL20OrHigher())
+ {
+ compiler_options_.push_back(CompilerOptions::CL_2_0);
+ }
+ bool kernel_is_trivial = conv_params_.x_kernel_is_1 && conv_params_.y_kernel_is_1;
+ if (definition_.src_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ kernel_is_trivial = kernel_is_trivial & conv_params_.z_kernel_is_1;
+ }
+ if (device_info.IsAdreno3xx() && definition_.precision == CalculationsPrecision::F16 &&
+ kernel_is_trivial)
+ {
+ compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE);
+ }
+}
+
+absl::Status ConvPowerVR::BindArguments(ArgumentsBinder *args)
+{
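+  // Bind only the arguments that GenerateConv declared: stride/padding/kernel/dilation
+  // per non-trivial kernel axis, plus the task sizes used for index decomposition.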
+ if (!conv_params_.x_kernel_is_1)
+ {
+ RETURN_IF_ERROR(args->SetInt("stride_x", stride_.x));
+ RETURN_IF_ERROR(args->SetInt("padding_x", padding_.x * src_[0]->Batch()));
+ RETURN_IF_ERROR(args->SetInt("kernel_size_x", kernel_size_.x));
+ RETURN_IF_ERROR(args->SetInt("dilation_x", dilation_.x * src_[0]->Batch()));
+ }
+ if (!conv_params_.y_kernel_is_1)
+ {
+ RETURN_IF_ERROR(args->SetInt("stride_y", stride_.y));
+ RETURN_IF_ERROR(args->SetInt("padding_y", padding_.y));
+ RETURN_IF_ERROR(args->SetInt("kernel_size_y", kernel_size_.y));
+ RETURN_IF_ERROR(args->SetInt("dilation_y", dilation_.y));
+ }
+ if (definition_.src_tensors[0].HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1)
+ {
+ RETURN_IF_ERROR(args->SetInt("stride_z", stride_.z));
+ RETURN_IF_ERROR(args->SetInt("padding_z", padding_.z));
+ RETURN_IF_ERROR(args->SetInt("kernel_size_z", kernel_size_.z));
+ RETURN_IF_ERROR(args->SetInt("dilation_z", dilation_.z));
+ }
+ if (conv_params_.linear_spatial)
+ {
+ const int grid_x =
+ DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), conv_params_.block_size.x);
+ RETURN_IF_ERROR(args->SetInt("task_size_x", grid_x));
+ }
+ if (definition_.src_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ const int task_size_y = DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y);
+ RETURN_IF_ERROR(args->SetInt("task_size_y", task_size_y));
+ }
+ return absl::OkStatus();
+}
+
+int3 ConvPowerVR::GetGridSize() const
+{
+ const int task_size_x =
+ DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), conv_params_.block_size.x);
+ const int task_size_y = DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y);
+ const int task_size_z = DivideRoundUp(dst_[0]->Depth(), conv_params_.block_size.z);
+ const int task_size_s = DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.w);
+ int3 wg;
+
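+  // linear_spatial flattens X and Y (and Z when the tensor has a DEPTH axis) into a
+  // single grid dimension; otherwise each maps to its own dimension.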
+ if (conv_params_.linear_spatial)
+ {
+ int grid_x = task_size_x * task_size_y;
+ if (definition_.src_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ grid_x *= task_size_z;
+ }
+ return int3(grid_x, task_size_s, 1);
+ }
+ else
+ {
+ int grid_y = task_size_y;
+ if (definition_.src_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ grid_y *= task_size_z;
+ }
+ return int3(task_size_x, grid_y, task_size_s);
+ }
+}
+
+void ConvPowerVR::GetPossibleKernelWorkGroups(TuningType tuning_type, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info,
+ std::vector<int3> *work_groups) const
+{
+ if (conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP ||
+ conv_params_.weights_upload_type == WeightsUploadType::LOCAL_MEM_BY_THREADS ||
+ conv_params_.fixed_work_group_size)
+ {
+ work_groups->push_back(work_group_size_);
+ return;
+ }
+ GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, work_groups);
+}
+
+std::string ConvPowerVR::GenerateConv(const DeviceInfo &device_info, const OperationDef &op_def,
+ bool stride_correction, const ConvParams &conv_params)
+{
+ auto src_desc = op_def.src_tensors[0];
+ src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
+ if (op_def.IsBatchSupported())
+ {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ AddSrcTensor("src_tensor", src_desc);
+ if (op_def.src_tensors.size() == 2)
+ {
+ // dynamic weights
+ BufferDescriptor desc;
+ desc.element_type = op_def.src_tensors[1].data_type;
+ desc.element_size = 4;
+ desc.memory_type =
+ conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::CONSTANT_MEM
+ ? MemoryType::CONSTANT
+ : MemoryType::GLOBAL;
+
+ AddSrcBuffer("weights", desc);
+ }
+
+ const auto &src_def = op_def.src_tensors[0];
+
+ auto generate_id = [&](const std::string &x, const std::string &y, const std::string &z) {
+ std::string id;
+ if (src_def.HasAxis(Axis::WIDTH))
+ {
+ id += "_w" + x;
+ }
+ if (src_def.HasAxis(Axis::HEIGHT))
+ {
+ id += "_h" + y;
+ }
+ if (src_def.HasAxis(Axis::DEPTH))
+ {
+ id += "_d" + z;
+ }
+ return id;
+ };
+
+ auto generate_id_full = [&](const std::string &x, const std::string &y, const std::string &z,
+ const std::string &s) { return generate_id(x, y, z) + "_s" + s; };
+
+ auto generate_check = [&](const std::string &x, const std::string &y, const std::string &z) {
+ std::string check;
+ const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
+ const std::vector<std::string> names{"in_x", "in_y", "in_z"};
+ const std::vector<bool> is_1{conv_params_.x_kernel_is_1, conv_params_.y_kernel_is_1,
+ conv_params_.z_kernel_is_1};
+ const std::vector<std::string> coords{x, y, z};
+ for (size_t i = 0; i < axes.size(); ++i)
+ {
+ const auto &axis = axes[i];
+ if (src_def.HasAxis(axis) && !src_def.SupportsZeroClamp(axis) && !is_1[i])
+ {
+ if (!check.empty())
+ {
+ check += " && ";
+ }
+ check += names[i] + coords[i];
+ }
+ }
+ return check;
+ };
+
+ auto dst_desc = op_def.dst_tensors[0];
+ if (op_def.IsBatchSupported())
+ {
+ dst_desc.SetStateVar("BatchedWidth", "true");
+ }
+ AddDstTensor("dst_tensor", dst_desc);
+
+ if (!conv_params_.x_kernel_is_1)
+ {
+ args_.AddInt("stride_x");
+ args_.AddInt("padding_x");
+ args_.AddInt("kernel_size_x");
+ args_.AddInt("dilation_x");
+ }
+ if (!conv_params_.y_kernel_is_1)
+ {
+ args_.AddInt("stride_y");
+ args_.AddInt("padding_y");
+ args_.AddInt("kernel_size_y");
+ args_.AddInt("dilation_y");
+ }
+ if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1)
+ {
+ args_.AddInt("stride_z");
+ args_.AddInt("padding_z");
+ args_.AddInt("kernel_size_z");
+ args_.AddInt("dilation_z");
+ }
+ if (conv_params_.linear_spatial)
+ {
+ args_.AddInt("task_size_x");
+ }
+ if (src_def.HasAxis(Axis::DEPTH))
+ {
+ args_.AddInt("task_size_y");
+ }
+
+ const bool need_local_mem =
+ conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS ||
+ conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
+
+ const int local_mem_size = conv_params.block_size.w * 4 * conv_params.src_depth_loop_size;
+
+ const bool use_simd_broadcast = conv_params.IsPrivateMemBroadcast();
+ const int simd_size = conv_params.simd_size;
+
+ const bool late_oob_check = need_local_mem || use_simd_broadcast;
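+  // Weights staged through local memory or SIMD broadcast require all work items to
+  // participate, so the out-of-bounds exit is deferred until the results are written.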
+
+ const std::string weights_space =
+ conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::CONSTANT_MEM ? "__constant"
+ : "__global";
+
+ const std::string weights_data_type =
+ conv_params.weights_data_type == DataType::FLOAT32 ? "float4" : "half4";
+
+ const std::string weights_global_ptr = weights_space + " " + weights_data_type + "*";
+
+ std::string c = GetCommonDefines(op_def.precision);
+ if (use_simd_broadcast)
+ {
+ if (device_info.cl_version == OpenCLVersion::CL_2_0)
+ {
+ c += "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
+ }
+ else if (device_info.SupportsExtension("cl_intel_subgroups"))
+ {
+ c += "#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n";
+ }
+ }
+ const int4 block_size = conv_params.block_size;
+ if (conv_params.fixed_work_group_size)
+ {
+ c += "__attribute__((reqd_work_group_size(" + std::to_string(work_group_size_.x) + ", " +
+ std::to_string(work_group_size_.y) + ", " + std::to_string(work_group_size_.z) + ")))\n";
+ }
+ if (use_simd_broadcast && device_info.IsIntel())
+ {
+ c += "__attribute__((intel_reqd_sub_group_size(" + std::to_string(simd_size) + ")))\n";
+ }
+ std::string dst_oob_check;
+ if (src_def.HasAxis(Axis::DEPTH))
+ {
+ if (conv_params.linear_spatial)
+ {
+ dst_oob_check = "DST_Z >= args.dst_tensor.Depth() || DST_S >= "
+ "args.dst_tensor.Slices()";
+ }
+ else
+ {
+ dst_oob_check = "DST_X >= args.dst_tensor.Width() || DST_Z >= "
+ "args.dst_tensor.Depth() || DST_S >= args.dst_tensor.Slices()";
+ }
+ }
+ else
+ {
+ if (conv_params.linear_spatial)
+ {
+ dst_oob_check = "DST_Y >= args.dst_tensor.Height() || DST_S >= "
+ "args.dst_tensor.Slices()";
+ }
+ else
+ {
+ dst_oob_check = "DST_X >= args.dst_tensor.Width() || DST_Y >= "
+ "args.dst_tensor.Height() || DST_S >= args.dst_tensor.Slices()";
+ }
+ }
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ c += GenerateBlockCoords(conv_params.block_size, work_group_launch_order_,
+ conv_params.linear_spatial, src_def.HasAxis(Axis::DEPTH));
+ if (!late_oob_check)
+ {
+ c += " if (" + dst_oob_check + ") {\n";
+ c += " return;\n";
+ c += " }\n";
+ }
+ if (conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS)
+ {
+ if (conv_params.linear_spatial)
+ {
+ c += " int lid = get_local_id(0);\n";
+ }
+ else
+ {
+ c += " int lid = get_local_id(1) * " + std::to_string(work_group_size_.x) +
+ " + get_local_id(0);\n";
+ }
+ }
+ if (use_simd_broadcast)
+ {
+ c += " int simd_id = get_sub_group_local_id();\n";
+ }
+ for (int s = 0; s < block_size.w; ++s)
+ {
+ const std::string sind = std::to_string(s);
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zind = std::to_string(z);
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yind = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xind = std::to_string(x);
+ c += " ACCUM_FLT4 r" + generate_id_full(xind, yind, zind, sind) +
+ " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
+ }
+ }
+ }
+ }
+ if (!conv_params_.x_kernel_is_1)
+ {
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xind = std::to_string(x);
+ const std::string xc = "(DST_X + " + xind + ")";
+ if (stride_correction)
+ {
+ c += " int xc" + xind + " = " +
+ GetXStrideCorrected(xc, "args.src_tensor.Batch()", "args.stride_x", "args.padding_x") +
+ ";\n";
+ }
+ else
+ {
+ c += " int xc" + xind + " = " + xc + " * args.stride_x + args.padding_x;\n";
+ }
+ }
+ }
+ else
+ {
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xind = std::to_string(x);
+ c += " int xc" + xind + " = DST_X + " + xind + ";\n";
+ if (!src_def.CanReadOutOfBorder(Axis::WIDTH))
+ {
+ c += " xc" + xind + " = clamp(xc" + xind + ", 0, args.src_tensor.Width() - 1);\n";
+ }
+ }
+ }
+ if (!conv_params_.y_kernel_is_1)
+ {
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yind = std::to_string(y);
+ const std::string yc = "(DST_Y + " + yind + ")";
+ c += " int yc" + yind + " = " + yc + " * args.stride_y + args.padding_y;\n";
+ }
+ }
+ else
+ {
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yind = std::to_string(y);
+ c += " int yc" + yind + " = DST_Y + " + yind + ";\n";
+ if (!src_def.CanReadOutOfBorder(Axis::HEIGHT))
+ {
+ c += " yc" + yind + " = clamp(yc" + yind + ", 0, args.src_tensor.Height() - 1);\n";
+ }
+ }
+ }
+ if (src_def.HasAxis(Axis::DEPTH))
+ {
+ if (!conv_params_.z_kernel_is_1)
+ {
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zind = std::to_string(z);
+ const std::string zc = "(DST_Z + " + zind + ")";
+ c += " int zc" + zind + " = " + zc + " * args.stride_z + args.padding_z;\n";
+ }
+ }
+ else
+ {
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zind = std::to_string(z);
+ c += " int zc" + zind + " = DST_Z + " + zind + ";\n";
+ if (!src_def.CanReadOutOfBorder(Axis::DEPTH))
+ {
+ c += " zc" + zind + " = clamp(zc" + zind + ", 0, args.src_tensor.Depth() - 1);\n";
+ }
+ }
+ }
+ }
+ bool trivial_kernel_size = conv_params_.x_kernel_is_1 && conv_params_.y_kernel_is_1;
+ if (src_def.HasAxis(Axis::DEPTH))
+ {
+ trivial_kernel_size = trivial_kernel_size && conv_params_.z_kernel_is_1;
+ }
+ if (need_local_mem)
+ {
+ c += " __local " + weights_data_type + " weights_cache[" + std::to_string(local_mem_size) +
+ "];\n";
+ }
+ else if (conv_params.AreWeightsBuffer())
+ {
+ c += " " + weights_global_ptr + " weights_cache;\n";
+ }
+ else if (!trivial_kernel_size)
+ {
+ c += " int filter_offset = 0;\n";
+ }
+ if (conv_params.AreWeightsBuffer())
+ {
+ if (conv_params.different_weights_for_height)
+ {
+ c += " " + weights_global_ptr +
+ " filters_loc = args.weights.GetPtr() + (DST_S * "
+ "args.src_tensor.Height() + DST_Y * " +
+ std::to_string(block_size.w) + ") * 4 * args.src_tensor.Slices();\n";
+ }
+ else
+ {
+ std::string kernel_spatial_offset = "";
+ if (!conv_params_.x_kernel_is_1)
+ {
+ kernel_spatial_offset += " * args.kernel_size_x";
+ }
+ if (!conv_params_.y_kernel_is_1)
+ {
+ kernel_spatial_offset += " * args.kernel_size_y";
+ }
+ if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1)
+ {
+ kernel_spatial_offset += " * args.kernel_size_z";
+ }
+ c += " " + weights_global_ptr +
+ " filters_loc = args.weights.GetPtr() + DST_S * 4 * "
+ "args.src_tensor.Slices()" +
+ kernel_spatial_offset + ";\n";
+ }
+ }
+ if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1)
+ {
+ c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n";
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zck = "zck" + std::to_string(z);
+ c += " int zck" + std::to_string(z) + " = kz * args.dilation_z + zc" + std::to_string(z) +
+ ";\n";
+ if (!src_def.SupportsZeroClamp(Axis::DEPTH))
+ {
+ c += " bool in_z" + std::to_string(z) + " = " + zck + " >= 0 && " + zck +
+ " < args.src_tensor.Depth();\n";
+ if (!src_def.CanReadOutOfBorder(Axis::DEPTH))
+ {
+ c += " " + zck + " = clamp(" + zck + ", 0, args.src_tensor.Depth() - 1);\n";
+ }
+ }
+ }
+ }
+ if (!conv_params_.y_kernel_is_1)
+ {
+ c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yck = "yck" + std::to_string(y);
+ c += " int " + yck + " = ky * args.dilation_y + yc" + std::to_string(y) + ";\n";
+ if (!src_def.SupportsZeroClamp(Axis::HEIGHT))
+ {
+ c += " bool in_y" + std::to_string(y) + " = " + yck + " >= 0 && " + yck +
+ " < args.src_tensor.Height();\n";
+ if (!src_def.CanReadOutOfBorder(Axis::HEIGHT))
+ {
+ c += " " + yck + " = clamp(" + yck + ", 0, args.src_tensor.Height() - 1);\n";
+ }
+ }
+ }
+ }
+ if (!conv_params_.x_kernel_is_1)
+ {
+ c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xck = "xck" + std::to_string(x);
+ c += " int xck" + std::to_string(x) + " = kx * args.dilation_x + xc" + std::to_string(x) +
+ ";\n";
+ if (!src_def.SupportsZeroClamp(Axis::WIDTH))
+ {
+ c += " bool in_x" + std::to_string(x) + " = " + xck + " >= 0 && " + xck +
+ " < args.src_tensor.Width();\n";
+ if (!src_def.CanReadOutOfBorder(Axis::WIDTH))
+ {
+ c += " " + xck + " = clamp(" + xck + ", 0, args.src_tensor.Width() - 1);\n";
+ }
+ }
+ }
+ }
+ const bool need_multiple_slice_strides =
+ src_def.ReturnsZeroForNegOneRead() && !trivial_kernel_size;
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zind = std::to_string(z);
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yind = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xind = std::to_string(x);
+ std::string xc = conv_params.x_kernel_is_1 ? "xc" + xind : "xck" + xind;
+ std::string yc = conv_params.y_kernel_is_1 ? "yc" + yind : "yck" + yind;
+ const std::string id = generate_id(xind, yind, zind);
+ std::string coords = "" + xc + ", " + yc;
+ if (src_def.HasAxis(Axis::DEPTH))
+ {
+ std::string zc = conv_params.z_kernel_is_1 ? "zc" + zind : "zck" + zind;
+ coords += ", " + zc;
+ }
+ if (src_def.IsLinear())
+ {
+ c += " args.src_tensor.GetAddress(addr" + id + ", " + coords + ", 0);\n";
+ if (need_multiple_slice_strides)
+ {
+ const std::string check = generate_check(xind, yind, zind);
+ c += " addr" + id + " = select(-1, addr" + id + ", (" + check + "));\n";
+ c +=
+ " int ds" + id + " = select(0, args.src_tensor.SliceStride(), (" + check + "));\n";
+ }
+ }
+ }
+ }
+ }
+ if (src_def.IsLinear() && !need_multiple_slice_strides)
+ {
+ c += " int ds = args.src_tensor.SliceStride();\n";
+ }
+
+ auto declare_src = [&]() {
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zind = std::to_string(z);
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yind = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xind = std::to_string(x);
+ const std::string id = generate_id(xind, yind, zind);
+ c += " " + weights_data_type + " src" + id + ";\n";
+ }
+ }
+ }
+ };
+ const bool conditional_read = device_info.IsMali();
+ auto read_src = [&]() {
+ const std::string cl_type = ToCLDataType(conv_params.weights_data_type);
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zind = std::to_string(z);
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yind = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xind = std::to_string(x);
+ std::string id = generate_id(xind, yind, zind);
+ const std::string check = generate_check(xind, yind, zind);
+ std::string address;
+ if (src_def.IsLinear())
+ {
+ address = "addr" + id;
+ }
+ else
+ {
+ std::string xc = conv_params.x_kernel_is_1 ? "xc" + xind : "xck" + xind;
+ std::string yc = conv_params.y_kernel_is_1 ? "yc" + yind : "yck" + yind;
+ address = "" + xc + ", " + yc;
+ if (src_def.HasAxis(Axis::DEPTH))
+ {
+ std::string zc = conv_params.z_kernel_is_1 ? "zc" + zind : "zck" + zind;
+ address += ", " + zc;
+ }
+ address += ", s";
+ }
+ if (src_def.ReturnsZeroForNegOneRead())
+ {
+ c += " src" + id + " = args.src_tensor.Read<" + cl_type + ">(" + address + ");\n";
+ const std::string ds = trivial_kernel_size ? "ds" : "ds" + id;
+ c += " " + address + " += " + ds + ";\n";
+ }
+ else
+ {
+ if (!check.empty())
+ {
+ if (conditional_read)
+ {
+ c += " src" + id + " = " + check + " ? args.src_tensor.Read<" + cl_type + ">(" +
+ address + ") : (FLT4)(0.0f);\n";
+ }
+ else
+ {
+ c += " src" + id + " = args.src_tensor.Read<" + cl_type + ">(" + address +
+ ") * (FLT)(" + check + ");\n";
+ }
+ }
+ else
+ {
+ c += " src" + id + " = args.src_tensor.Read<" + cl_type + ">(" + address + ");\n";
+ }
+ if (src_def.IsLinear())
+ {
+ c += " " + address + " += ds;\n";
+ }
+ }
+ }
+ }
+ }
+ };
+ const bool weights_type_as_accum_type = !(op_def.precision == CalculationsPrecision::F32_F16 &&
+ conv_params.weights_data_type == DataType::FLOAT16);
+ auto conv_core = [&](int shared_offset) {
+ const std::string channels[] = {"x", "y", "z", "w"};
+ for (int s = 0; s < block_size.w; ++s)
+ {
+ const std::string sind = std::to_string(s);
+ if (weights_type_as_accum_type)
+ {
+ for (int ch = 0; ch < 4; ++ch)
+ {
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zind = std::to_string(z);
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yind = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xind = std::to_string(x);
+ std::string R = "r" + generate_id_full(xind, yind, zind, sind);
+ std::string S = "src" + generate_id(xind, yind, zind);
+ if (use_simd_broadcast)
+ {
+ int simd_id = (s * 4 + ch + shared_offset) / simd_size;
+ int thread_id = (s * 4 + ch + shared_offset) % simd_size;
+ std::string w_val_x = "sub_group_broadcast(simd_w" + std::to_string(simd_id) +
+ ".x, " + std::to_string(thread_id) + "u)";
+ std::string w_val_y = "sub_group_broadcast(simd_w" + std::to_string(simd_id) +
+ ".y, " + std::to_string(thread_id) + "u)";
+ std::string w_val_z = "sub_group_broadcast(simd_w" + std::to_string(simd_id) +
+ ".z, " + std::to_string(thread_id) + "u)";
+ std::string w_val_w = "sub_group_broadcast(simd_w" + std::to_string(simd_id) +
+ ".w, " + std::to_string(thread_id) + "u)";
+ c += " " + R + ".x += " + w_val_x + " * " + S + "." + channels[ch] + ";\n";
+ c += " " + R + ".y += " + w_val_y + " * " + S + "." + channels[ch] + ";\n";
+ c += " " + R + ".z += " + w_val_z + " * " + S + "." + channels[ch] + ";\n";
+ c += " " + R + ".w += " + w_val_w + " * " + S + "." + channels[ch] + ";\n";
+ }
+ else
+ {
+ const std::string weight_id = std::to_string(s * 4 + ch + shared_offset);
+ std::string w_val;
+ if (conv_params.AreWeightsBuffer())
+ {
+ w_val = "weights_cache[" + weight_id + "]";
+ }
+ else
+ {
+ w_val = "f" + weight_id;
+ }
+ c += " " + R + " += " + w_val + " * " + S + "." + channels[ch] + ";\n";
+ }
+ }
+ }
+ }
+ }
+ }
+ else
+ { // F32_F16 precision and weights type is float16
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zind = std::to_string(z);
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yind = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xind = std::to_string(x);
+ std::string R = "r" + generate_id_full(xind, yind, zind, sind);
+ std::string S = "src" + generate_id(xind, yind, zind);
+ std::vector<std::string> F(4);
+ for (int i = 0; i < 4; ++i)
+ {
+ std::string weight_id = std::to_string(s * 4 + i + shared_offset);
+ if (conv_params.AreWeightsBuffer())
+ {
+ F[i] = "weights_cache[" + weight_id + "]";
+ }
+ else
+ {
+ F[i] = "f" + weight_id;
+ }
+ }
+ c += " " + R + " += convert_float4(" + S + ".x * " + F[0] + " + " + S + ".y * " +
+ F[1] + " + " + S + ".z * " + F[2] + " + " + S + ".w * " + F[3] + ");\n";
+ }
+ }
+ }
+ }
+ }
+ };
+
+ c += " int s = 0;\n";
+ c += " do {\n";
+ declare_src();
+ const int total_work_items = work_group_size_.x * work_group_size_.y * work_group_size_.z;
+ if (conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP)
+ {
+ c += GenerateAsyncUpload("weights_cache", "filters_loc",
+ /*global_offset_name*/ "", local_mem_size);
+ }
+ else if (conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS)
+ {
+ c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
+ c +=
+ GenerateUploadByThreads("weights_cache", "filters_loc",
+ /*global_offset_name*/ "", "lid", total_work_items, local_mem_size);
+ }
+ else if (use_simd_broadcast)
+ {
+ int parts = local_mem_size / simd_size;
+ int reminder = local_mem_size % simd_size;
+ for (int i = 0; i < parts; ++i)
+ {
+ c += " FLT4 simd_w" + std::to_string(i) + " = filters_loc[simd_id + " +
+ std::to_string(i * simd_size) + "];\n";
+ }
+ if (reminder)
+ {
+ c += " FLT4 simd_w" + std::to_string(parts) + ";\n";
+ c += " if (simd_id < " + std::to_string(reminder) + ") {\n";
+ c += " simd_w" + std::to_string(parts) + " = filters_loc[simd_id + " +
+ std::to_string(parts * simd_size) + "];\n";
+ c += " }\n";
+ }
+ }
+ else if (conv_params.AreWeightsBuffer())
+ { // GLOBAL_MEM/CONSTANT_MEM
+ c += " weights_cache = filters_loc;\n";
+ }
+ else
+ { // TEXTURES_MEM
+ for (int dst_s = 0; dst_s < block_size.w; ++dst_s)
+ {
+ std::string f_y = trivial_kernel_size ? "s" : "filter_offset";
+ if (conv_params.different_weights_for_height)
+ {
+ f_y = "DST_Y * args.src_tensor.Slices() + s";
+ }
+ c += absl::Substitute(
+ R"( FLT4 f$2 = args.weights0.Read(DST_S + $0, $1);
+ FLT4 f$3 = args.weights1.Read(DST_S + $0, $1);
+ FLT4 f$4 = args.weights2.Read(DST_S + $0, $1);
+ FLT4 f$5 = args.weights3.Read(DST_S + $0, $1);
+)",
+ dst_s, f_y, dst_s * 4 + 0, dst_s * 4 + 1, dst_s * 4 + 2, dst_s * 4 + 3);
+ }
+ if (!trivial_kernel_size)
+ {
+ c += " filter_offset++;\n";
+ }
+ }
+ read_src();
+ c += " s += 1;\n";
+ if (conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS)
+ {
+ c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
+ }
+ conv_core(0);
+ for (int i = 1; i < conv_params.src_depth_loop_size; ++i)
+ {
+ read_src();
+ conv_core(i * block_size.w * 4);
+ c += " s += 1;\n";
+ }
+ if (conv_params.AreWeightsBuffer())
+ {
+ c += " filters_loc += " + std::to_string(local_mem_size) + ";\n";
+ }
+ c += " } while (s < args.src_tensor.Slices());\n";
+ if (!conv_params.x_kernel_is_1)
+ {
+ c += " };\n";
+ }
+ if (!conv_params.y_kernel_is_1)
+ {
+ c += " };\n";
+ }
+ if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1)
+ {
+ c += " };\n";
+ }
+ if (conv_params.AreWeightsBuffer())
+ {
+ if (conv_params.weights_upload_type == ConvPowerVR::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP)
+ {
+ c += GenerateAsyncUpload("weights_cache", "args.biases.GetPtr()", "DST_S", block_size.w);
+ }
+ else if (conv_params.weights_upload_type ==
+ ConvPowerVR::WeightsUploadType::LOCAL_MEM_BY_THREADS)
+ {
+ c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
+ c += GenerateUploadByThreads("weights_cache", "args.biases.GetPtr()", "DST_S", "lid",
+ total_work_items, block_size.w);
+ c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
+ }
+ else
+ {
+ c += " weights_cache = args.biases.GetPtr() + DST_S;\n";
+ }
+ }
+ if (late_oob_check)
+ {
+ c += " if (" + dst_oob_check + ") {\n";
+ c += " return;\n";
+ c += " }\n";
+ }
+
+ auto generate_dst_check = [&](int x, int y, int z) {
+ std::string check;
+ const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
+ const std::vector<std::string> names{"Width()", "Height()", "Depth()"};
+ std::vector<std::string> coords(3);
+ coords[0] = "DST_X + " + std::to_string(x);
+ coords[1] = "DST_Y + " + std::to_string(y);
+ coords[2] = "DST_Z + " + std::to_string(z);
+ const std::vector<int> ids{x, y, z};
+ for (size_t i = 0; i < axes.size(); ++i)
+ {
+ const auto &axis = axes[i];
+ if (src_def.HasAxis(axis) && ids[i] != 0)
+ {
+ if (!check.empty())
+ {
+ check += " && ";
+ }
+ check += coords[i] + " < args.dst_tensor." + names[i];
+ }
+ }
+ return check;
+ };
+
+ for (int s = 0; s < block_size.w; ++s)
+ {
+ const std::string sind = std::to_string(s);
+ c += " if (DST_S + " + sind + " >= args.dst_tensor.Slices()) return;\n";
+ c += " {\n";
+ if (conv_params.AreWeightsBuffer())
+ {
+ c += " FLT4 bias_val = TO_FLT4(weights_cache[" + sind + "]);\n";
+ }
+ else
+ {
+ c += " FLT4 bias_val = args.biases.Read(DST_S + " + sind + ");\n";
+ }
+ for (int z = 0; z < block_size.z; ++z)
+ {
+ const std::string zind = std::to_string(z);
+ for (int y = 0; y < block_size.y; ++y)
+ {
+ const std::string yind = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x)
+ {
+ const std::string xind = std::to_string(x);
+ const std::string id = generate_id_full(xind, yind, zind, sind);
+ const std::string check = generate_dst_check(x, y, z);
+ std::string coords = "DST_X + " + xind + ", DST_Y + " + yind;
+ if (src_def.HasAxis(Axis::DEPTH))
+ {
+ coords += ", DST_Z + " + zind;
+ }
+ coords += ", DST_S + " + sind;
+ if (!check.empty())
+ {
+ c += " if (" + check + ") {\n";
+ }
+ else
+ {
+ c += " {\n";
+ }
+ c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n";
+ c += " args.dst_tensor.Write(res, " + coords + ");\n";
+ c += " }\n";
+ }
+ }
+ }
+ c += " }\n";
+ }
+ c += "}\n";
+ return c;
+}
+
+ConvPowerVR::ConvParams
+ConvPowerVR::GuessBestParams(const DeviceInfo &device_info, const OperationDef &definition,
+ int src_depth, int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1,
+ bool different_weights_for_height, const BHWC *dst_shape)
+{
+ ConvParams conv_params;
+ conv_params.linear_spatial = false;
+ conv_params.weights_data_type = DeduceDataTypeFromPrecision(definition.precision);
+ conv_params.x_kernel_is_1 = x_kernel_is_1;
+ conv_params.y_kernel_is_1 = y_kernel_is_1;
+ conv_params.different_weights_for_height = different_weights_for_height;
+ if (device_info.IsNvidia())
+ {
+ if (different_weights_for_height)
+ {
+ work_group_size_ = int3(32, 1, 1);
+ work_group_launch_order_ = int3(2, 0, 1);
+ conv_params.fixed_work_group_size = true;
+ }
+ else
+ {
+ conv_params.linear_spatial = true;
+ work_group_size_ = int3(32, 1, 1);
+ work_group_launch_order_ = int3(1, 0, 2);
+ conv_params.fixed_work_group_size = true;
+ }
+ conv_params.block_size = int4(2, 1, 1, 4);
+ conv_params.src_depth_loop_size = 1;
+ conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
+ if (dst_depth % 4 == 0 || dst_depth >= 8)
+ {
+ conv_params.block_size.w = 4;
+ }
+ else if (dst_depth % 2 == 0 || dst_depth >= 4)
+ {
+ conv_params.block_size.w = 2;
+ }
+ else
+ {
+ conv_params.block_size.w = dst_depth;
+ }
+ if (dst_shape)
+ {
+ int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
+ float task_size_per_cu = static_cast<float>(task_size) / device_info.compute_units_count;
+ int block_size =
+ conv_params.block_size.x * conv_params.block_size.y * conv_params.block_size.w;
+ float threads_per_cu = task_size_per_cu / block_size;
+ float warps_per_cu = threads_per_cu / 32 /*warp_size*/;
+ if (warps_per_cu < 8.0f)
+ {
+ conv_params.block_size.x = 1;
+ }
+ if (warps_per_cu < 4.0f && conv_params.block_size.w >= 4)
+ {
+ conv_params.block_size.w /= 2;
+ }
+ if (warps_per_cu < 2.0f && conv_params.block_size.w >= 2)
+ {
+ conv_params.block_size.w /= 2;
+ }
+ }
+ if (src_depth % 2 == 0)
+ {
+ conv_params.src_depth_loop_size = 2;
+ }
+ if (src_depth % 4 == 0 && conv_params.block_size.w <= 2)
+ {
+ conv_params.src_depth_loop_size = 4;
+ }
+ }
+ else if (device_info.IsPowerVR())
+ {
+ if (different_weights_for_height)
+ {
+ work_group_size_ = int3(32, 1, 1);
+ work_group_launch_order_ = int3(2, 0, 1);
+ conv_params.fixed_work_group_size = true;
+ }
+ else
+ {
+ conv_params.linear_spatial = true;
+ work_group_size_ = int3(32, 1, 1);
+ work_group_launch_order_ = int3(1, 0, 2);
+ conv_params.fixed_work_group_size = true;
+ }
+ conv_params.weights_data_type =
+ definition.precision == CalculationsPrecision::F16 ? DataType::FLOAT16 : DataType::FLOAT32;
+ conv_params.block_size = int4(1, 1, 1, 4);
+ conv_params.src_depth_loop_size = 1;
+ conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
+ if (dst_depth % 8 == 0 || dst_depth >= 32)
+ {
+ conv_params.block_size.w = 8;
+ }
+ else if (dst_depth % 4 == 0 || dst_depth >= 8)
+ {
+ conv_params.block_size.w = 4;
+ }
+ else if (dst_depth % 2 == 0 || dst_depth >= 4)
+ {
+ conv_params.block_size.w = 2;
+ }
+ else
+ {
+ conv_params.block_size.w = dst_depth;
+ }
+ if (definition.precision == CalculationsPrecision::F16)
+ {
+ conv_params.block_size.w = std::min(4, conv_params.block_size.w);
+ if (src_depth % 2 == 0)
+ {
+ conv_params.src_depth_loop_size = 2;
+ }
+ if (src_depth % 4 == 0 && conv_params.block_size.w <= 2)
+ {
+ conv_params.src_depth_loop_size = 4;
+ }
+ if (conv_params.block_size.w == 1)
+ {
+ if (src_depth % 2 == 0)
+ {
+ conv_params.src_depth_loop_size = 2;
+ }
+ if (src_depth % 4 == 0)
+ {
+ conv_params.src_depth_loop_size = 4;
+ }
+ if (src_depth <= 8)
+ {
+ conv_params.src_depth_loop_size = src_depth;
+ }
+ }
+ conv_params.block_size.x = 2;
+ }
+ }
+ else if (device_info.IsAMD())
+ {
+ if (different_weights_for_height)
+ {
+ work_group_size_ = int3(32, 1, 1);
+ work_group_launch_order_ = int3(2, 0, 1);
+ conv_params.fixed_work_group_size = true;
+ }
+ else
+ {
+ work_group_size_ = int3(8, 4, 1);
+ work_group_launch_order_ = int3(2, 0, 1);
+ conv_params.fixed_work_group_size = true;
+ }
+
+ conv_params.block_size = int4(2, 1, 1, 1);
+ if (x_kernel_is_1 && y_kernel_is_1)
+ {
+ conv_params.block_size.y = 2;
+ }
+ conv_params.src_depth_loop_size = 1;
+ conv_params.weights_upload_type = WeightsUploadType::CONSTANT_MEM;
+ if (dst_depth % 8 == 0 || dst_depth >= 32)
+ {
+ conv_params.block_size.w = 8;
+ }
+ else if (dst_depth % 4 == 0 || dst_depth >= 8)
+ {
+ conv_params.block_size.w = 4;
+ }
+ else if (dst_depth % 2 == 0 || dst_depth >= 4)
+ {
+ conv_params.block_size.w = 2;
+ }
+ else
+ {
+ conv_params.block_size.w = 1;
+ }
+ if (src_depth % 2 == 0 && src_depth >= 16)
+ {
+ conv_params.src_depth_loop_size = 2;
+ }
+ }
+ else if (device_info.IsMali())
+ {
+ int block_size = 2;
+ if (dst_shape)
+ {
+ int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
+ block_size = GetRecommendedBlockSizeForConv(device_info, definition.precision, task_size);
+ }
+ if (!x_kernel_is_1 || !y_kernel_is_1)
+ {
+ block_size = std::min(block_size, 4);
+ }
+ if (block_size == 8)
+ {
+ if (dst_depth == 1 || dst_depth == 3)
+ {
+ conv_params.block_size = int4(2, 2, 1, 1);
+ }
+ else
+ {
+ conv_params.block_size = int4(2, 2, 1, 2);
+ }
+ }
+ else if (block_size == 4)
+ {
+ if (dst_depth == 1 || dst_depth == 3)
+ {
+ conv_params.block_size = int4(2, 2, 1, 1);
+ }
+ else
+ {
+ conv_params.block_size = int4(2, 1, 1, 2);
+ }
+ }
+ else if (block_size == 2)
+ {
+ conv_params.block_size = int4(2, 1, 1, 1);
+ }
+ else
+ {
+ conv_params.block_size = int4(1, 1, 1, 1);
+ }
+ conv_params.src_depth_loop_size = 1;
+ MaliInfo mali_info = device_info.mali_info;
+ if (src_depth % 2 == 0 && block_size <= 2 && !mali_info.IsMidgard())
+ {
+ conv_params.src_depth_loop_size = 2;
+ }
+ if (src_depth % 4 == 0 && block_size == 1 && !mali_info.IsMidgard() &&
+ definition.precision == CalculationsPrecision::F16)
+ {
+ conv_params.src_depth_loop_size = 4;
+ }
+ work_group_size_ = int3(4, 4, 1);
+ work_group_launch_order_ = int3(0, 1, 2);
+ conv_params.fixed_work_group_size = false;
+ conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
+ }
+ else if (device_info.IsAdreno())
+ {
+ conv_params.block_size = int4(2, 2, 1, 2);
+ if (device_info.IsAdreno3xx())
+ {
+ if (definition.precision == CalculationsPrecision::F16)
+ {
+ conv_params.block_size = int4(2, 2, 1, 2);
+ }
+ else if (definition.precision == CalculationsPrecision::F32_F16)
+ {
+ conv_params.block_size = int4(2, 1, 1, 2);
+ }
+ else
+ { // F32
+ conv_params.block_size = int4(2, 2, 1, 1);
+ }
+ }
+ work_group_size_ = int3(8, 2, 1);
+ work_group_launch_order_ = int3(0, 1, 2);
+ conv_params.fixed_work_group_size = false;
+ conv_params.src_depth_loop_size = 1;
+ if (definition.src_tensors.size() == 2)
+ {
+ // dynamic weights supported only with buffers.
+ conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
+ }
+ else
+ {
+ conv_params.weights_upload_type = WeightsUploadType::TEXTURES_MEM_X4;
+ }
+ }
+ else if (device_info.IsIntel())
+ {
+ if (different_weights_for_height)
+ {
+ work_group_size_ = int3(16, 1, 1);
+ work_group_launch_order_ = int3(0, 1, 2);
+ conv_params.fixed_work_group_size = true;
+ }
+ else
+ {
+ conv_params.linear_spatial = true;
+ work_group_size_ = int3(16, 1, 1);
+ work_group_launch_order_ = int3(0, 1, 2);
+ conv_params.fixed_work_group_size = true;
+ }
+ conv_params.block_size = int4(1, 1, 1, 4);
+ conv_params.src_depth_loop_size = 1;
+ int sub_group_size = 16;
+ const bool supports_subgroups = device_info.SupportsExtension("cl_khr_subgroups") ||
+ device_info.SupportsExtension("cl_intel_subgroups");
+ if (definition.precision != CalculationsPrecision::F32_F16 && supports_subgroups &&
+ device_info.SupportsExtension("cl_intel_required_subgroup_size") &&
+ device_info.SupportsSubGroupWithSize(sub_group_size))
+ {
+ conv_params.weights_upload_type = WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST;
+ conv_params.simd_size = sub_group_size;
+ }
+ else
+ {
+ conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
+ }
+ if (dst_depth % 4 == 0 || dst_depth >= 8)
+ {
+ conv_params.block_size.w = 4;
+ }
+ else if (dst_depth % 2 == 0 || dst_depth >= 4)
+ {
+ conv_params.block_size.w = 2;
+ }
+ else
+ {
+ conv_params.block_size.w = dst_depth;
+ }
+ if (src_depth % 2 == 0)
+ {
+ conv_params.src_depth_loop_size = 2;
+ }
+ if (src_depth % 4 == 0 && conv_params.block_size.w <= 2)
+ {
+ conv_params.src_depth_loop_size = 4;
+ }
+ }
+ else
+ {
+ conv_params.block_size = int4(1, 1, 1, 4);
+ work_group_size_ = int3(8, 2, 1);
+ work_group_launch_order_ = int3(0, 1, 2);
+ conv_params.fixed_work_group_size = false;
+ conv_params.src_depth_loop_size = 1;
+ conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
+ if (dst_depth % 4 == 0 || dst_depth >= 8)
+ {
+ conv_params.block_size.w = 4;
+ }
+ else if (dst_depth % 2 == 0 || dst_depth >= 4)
+ {
+ conv_params.block_size.w = 2;
+ }
+ else
+ {
+ conv_params.block_size.w = dst_depth;
+ }
+ if (src_depth % 2 == 0)
+ {
+ conv_params.src_depth_loop_size = 2;
+ }
+ if (src_depth % 4 == 0 && conv_params.block_size.w <= 2)
+ {
+ conv_params.src_depth_loop_size = 4;
+ }
+ }
+
+ return conv_params;
+}
+
+ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC *dst_shape)
+{
+ const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
+ const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
+ const bool x_kernel_is_1 = attr.weights.shape.w == 1 && attr.strides.w == 1 &&
+ attr.dilations.w == 1 && attr.padding.prepended.w == 0 &&
+ attr.padding.appended.w == 0;
+ const bool y_kernel_is_1 = attr.weights.shape.h == 1 && attr.strides.h == 1 &&
+ attr.dilations.h == 1 && attr.padding.prepended.h == 0 &&
+ attr.padding.appended.h == 0;
+ return GuessBestParams(device_info, definition, src_depth, dst_depth, x_kernel_is_1,
+ y_kernel_is_1, false, dst_shape);
+}
+
+ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution3DAttributes &attr,
+ const BHWDC *dst_shape)
+{
+ const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
+ const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
+ const bool x_kernel_is_1 = attr.weights.shape.w == 1 && attr.strides.w == 1 &&
+ attr.dilations.w == 1 && attr.padding.prepended.w == 0 &&
+ attr.padding.appended.w == 0;
+ const bool y_kernel_is_1 = attr.weights.shape.h == 1 && attr.strides.h == 1 &&
+ attr.dilations.h == 1 && attr.padding.prepended.h == 0 &&
+ attr.padding.appended.h == 0;
+ const bool z_kernel_is_1 = attr.weights.shape.d == 1 && attr.strides.d == 1 &&
+ attr.dilations.d == 1 && attr.padding.prepended.d == 0 &&
+ attr.padding.appended.d == 0;
+
+ ConvPowerVR::ConvParams result;
+ BHWC shape;
+ if (dst_shape)
+ {
+ shape.b = dst_shape->b;
+ shape.h = dst_shape->h * dst_shape->d;
+ shape.w = dst_shape->w;
+ shape.c = dst_shape->c;
+ result = GuessBestParams(device_info, definition, src_depth, dst_depth, x_kernel_is_1,
+ y_kernel_is_1, false, &shape);
+ }
+ else
+ {
+ result = GuessBestParams(device_info, definition, src_depth, dst_depth, x_kernel_is_1,
+ y_kernel_is_1, false, nullptr);
+ }
+ result.z_kernel_is_1 = z_kernel_is_1;
+ return result;
+}
+
+ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC &weights_shape,
+ const BHWC *dst_shape)
+{
+ const int dst_depth = DivideRoundUp(weights_shape.b, 4);
+ const int src_depth = DivideRoundUp(weights_shape.c, 4);
+ const bool x_kernel_is_1 = weights_shape.w == 1 && attr.strides.w == 1 && attr.dilations.w == 1 &&
+ attr.padding.prepended.w == 0 && attr.padding.appended.w == 0;
+ const bool y_kernel_is_1 = weights_shape.h == 1 && attr.strides.h == 1 && attr.dilations.h == 1 &&
+ attr.padding.prepended.h == 0 && attr.padding.appended.h == 0;
+ return GuessBestParams(device_info, definition, src_depth, dst_depth, x_kernel_is_1,
+ y_kernel_is_1, false, dst_shape);
+}
+
+ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const FullyConnectedAttributes &attr,
+ const BHWC *dst_shape)
+{
+ const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
+ const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
+ ConvPowerVR::ConvParams params =
+ GuessBestParams(device_info, definition, src_depth, dst_depth, true, true, false, dst_shape);
+ work_group_size_.x *= work_group_size_.y;
+ work_group_size_.y = 1;
+ params.block_size.x *= params.block_size.y;
+ params.block_size.y = 1;
+ return params;
+}
+
+ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC *dst_shape)
+{
+ const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
+ const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
+ ConvPowerVR::ConvParams params =
+ GuessBestParams(device_info, definition, src_depth, dst_depth, true, true, true, dst_shape);
+ params.block_size.x *= params.block_size.y;
+ params.block_size.y = 1;
+ return params;
+}
+
+ConvPowerVR CreateConvPowerVR(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr, const BHWC *dst_shape)
+{
+ ConvPowerVR result(definition, attr, device_info, dst_shape);
+ result.GenerateCode(device_info);
+ result.UploadData(attr.weights, attr.bias);
+ return result;
+}
+
+ConvPowerVR CreateConvPowerVR(const DeviceInfo &device_info, const OperationDef &definition,
+ const FullyConnectedAttributes &attr, const BHWC *dst_shape)
+{
+ ConvPowerVR result(definition, attr, device_info, dst_shape);
+ result.GenerateCode(device_info);
+ result.UploadData(attr.weights, attr.bias);
+ return result;
+}
+
+ConvPowerVR CreateConvPowerVRDynamicWeights(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC &weights_shape, const BHWC *dst_shape)
+{
+ ConvPowerVR result(definition, attr, weights_shape, device_info, dst_shape);
+ result.GenerateCode(device_info);
+ result.UploadBias(attr.bias);
+ return result;
+}
+
+ConvPowerVR CreateConvPowerVRWino4x4To6x6(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC *dst_shape)
+{
+ ConvPowerVR result(definition);
+ result.conv_params_ = result.GuessBestParamsWinograd(device_info, definition, attr, dst_shape);
+ result.GenerateCode(device_info);
+ result.UploadDataForWinograd4x4To6x6(attr.weights);
+ return result;
+}
+
+ConvPowerVR CreateConvPowerVR3D(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution3DAttributes &attr, const BHWDC *dst_shape)
+{
+ ConvPowerVR result(definition, attr, device_info, dst_shape);
+ result.GenerateCode(device_info);
+ result.UploadWeights(attr.weights);
+ result.UploadBias(attr.bias);
+ return result;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_POWERVR_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_POWERVR_H__
+
+#include <cstring>
+#include <vector>
+
+#include "open_cl/Buffer.h"
+#include "open_cl/ClDevice.h"
+#include "open_cl/kernels/ConvCommon.h"
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/kernels/Util.h"
+#include "open_cl/LinearStorage.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Texture2d.h"
+#include "open_cl/Util.h"
+#include "open_cl/DataType.h"
+#include "open_cl/Operations.h"
+#include "open_cl/Shape.h"
+#include "open_cl/Status.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Types.h"
+#include "open_cl/WinogradUtil.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class ConvPowerVR : public GPUOperation
+{
+public:
+ ConvPowerVR() = default;
+ void GetPossibleKernelWorkGroups(TuningType tuning_type, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info,
+ std::vector<int3> *work_groups) const override;
+ absl::Status BindArguments(ArgumentsBinder *args) override;
+ int3 GetGridSize() const override;
+
+ ConvWeightsDescription GetConvWeightsDescription() const
+ {
+ ConvWeightsDescription desc;
+ desc.layout = ConvWeightsLayout::kOHWIOGroupI4O4;
+ desc.output_group_size = conv_params_.block_size.w;
+ return desc;
+ }
+
+ // Move only
+ ConvPowerVR(ConvPowerVR &&operation);
+ ConvPowerVR &operator=(ConvPowerVR &&operation);
+ ConvPowerVR(const ConvPowerVR &) = delete;
+ ConvPowerVR &operator=(const ConvPowerVR &) = delete;
+
+private:
+ enum class WeightsUploadType
+ {
+ LOCAL_MEM_ASYNC_SUBGROUP, // we use it for PowerVR with workgroup size = 32
+ LOCAL_MEM_BY_THREADS,
+ GLOBAL_MEM,
+ CONSTANT_MEM,
+ PRIVATE_MEM_SIMD_BROADCAST,
+ TEXTURES_MEM_X4, // 4 textures for weights
+ };
+
+ struct ConvParams
+ {
+ // Usually we use these combinations for CalculationsPrecision:
+ // F32: all F32
+ // F16: all F16
+ // F32_F16: everything except the accumulator is F16, including weights.
+ // But for PowerVR we can achieve better performance in F32_F16 with F32
+ // weights, so for PowerVR this kernel uses F32 weights in the
+ // F32_F16 precision mode.
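+ //
+ // A minimal sketch of the mapping described above (illustrative only; the
+ // authoritative logic is DeduceDataTypeFromPrecision() plus the per-vendor
+ // branches in GuessBestParams()):
+ //   F32     -> weights_data_type = FLOAT32
+ //   F16     -> weights_data_type = FLOAT16
+ //   F32_F16 -> weights_data_type = FLOAT16 (FLOAT32 on PowerVR, see above)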
+ DataType weights_data_type; // used for weights and biases
+ int4 block_size; // WHDS
+ bool fixed_work_group_size;
+ bool linear_spatial; // spatial dimensions are Width/Height/Depth
+ bool different_weights_for_height;
+ int src_depth_loop_size;
+ WeightsUploadType weights_upload_type;
+ bool x_kernel_is_1;
+ bool y_kernel_is_1;
+ bool z_kernel_is_1;
+
+ // used only with PRIVATE_MEM_SIMD_BROADCAST
+ int simd_size = 1;
+
+ bool AreWeightsBuffer() const
+ {
+ return weights_upload_type != WeightsUploadType::TEXTURES_MEM_X4;
+ }
+
+ bool IsPrivateMemBroadcast() const
+ {
+ return weights_upload_type == WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST;
+ }
+ };
+
+ ConvPowerVR(const OperationDef &definition, const Convolution2DAttributes &attr,
+ const DeviceInfo &device_info, const BHWC *dst_shape = nullptr);
+ ConvPowerVR(const OperationDef &definition, const Convolution2DAttributes &attr,
+ const BHWC &weights_shape, const DeviceInfo &device_info,
+ const BHWC *dst_shape = nullptr);
+ ConvPowerVR(const OperationDef &definition, const FullyConnectedAttributes &attr,
+ const DeviceInfo &device_info, const BHWC *dst_shape = nullptr);
+ explicit ConvPowerVR(const OperationDef &definition);
+ ConvPowerVR(const OperationDef &definition, const Convolution3DAttributes &attr,
+ const DeviceInfo &device_info, const BHWDC *dst_shape = nullptr);
+
+ void GenerateCode(const DeviceInfo &device_info);
+
+ template <DataType T>
+ void UploadData(const InternalTensor<OHWI, T> &weights, const InternalTensor<Linear, T> &biases);
+ template <DataType T> void UploadDataForWinograd4x4To6x6(const InternalTensor<OHWI, T> &weights);
+
+ template <DataType T> void UploadWeights(const InternalTensor<OHWI, T> &weights);
+
+ template <DataType T> void UploadWeights(const InternalTensor<OHWDI, T> &weights);
+
+ template <DataType T> void UploadBias(const InternalTensor<Linear, T> &bias);
+
+ friend ConvPowerVR CreateConvPowerVR(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr, const BHWC *dst_shape);
+
+ friend ConvPowerVR CreateConvPowerVR(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const FullyConnectedAttributes &attr, const BHWC *dst_shape);
+
+ friend ConvPowerVR CreateConvPowerVRDynamicWeights(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC &weights_shape,
+ const BHWC *dst_shape);
+
+ friend ConvPowerVR CreateConvPowerVRWino4x4To6x6(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC *dst_shape);
+
+ friend ConvPowerVR CreateConvPowerVR3D(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution3DAttributes &attr,
+ const BHWDC *dst_shape);
+
+ ConvParams GuessBestParams(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr, const BHWC *dst_shape = nullptr);
+ ConvParams GuessBestParams(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr, const BHWC &weights_shape,
+ const BHWC *dst_shape = nullptr);
+ ConvParams GuessBestParams(const DeviceInfo &device_info, const OperationDef &definition,
+ const FullyConnectedAttributes &attr, const BHWC *dst_shape = nullptr);
+ ConvParams GuessBestParamsWinograd(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC *dst_shape = nullptr);
+ ConvParams GuessBestParams(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution3DAttributes &attr, const BHWDC *dst_shape = nullptr);
+ ConvParams GuessBestParams(const DeviceInfo &device_info, const OperationDef &definition,
+ int src_depth, int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1,
+ bool different_weights_for_height, const BHWC *dst_shape = nullptr);
+
+ std::string GenerateConv(const DeviceInfo &device_info, const OperationDef &op_def,
+ bool stride_correction, const ConvParams &conv_params);
+
+ int4 stride_;
+ int4 padding_;
+ int4 kernel_size_;
+ int4 dilation_;
+ ConvParams conv_params_;
+};
+
+template <DataType T>
+void ConvPowerVR::UploadData(const InternalTensor<OHWI, T> &weights,
+ const InternalTensor<Linear, T> &biases)
+{
+ UploadWeights(weights);
+ UploadBias(biases);
+}
+
+template <DataType T>
+void ConvPowerVR::UploadDataForWinograd4x4To6x6(const InternalTensor<OHWI, T> &weights)
+{
+ InternalTensor<OHWI, T> wino_weights;
+ RearrangeWeightsToWinograd4x4To6x6Weights(weights, &wino_weights);
+ UploadWeights(wino_weights);
+ InternalTensor<Linear, DataType::FLOAT32> biases;
+ biases.shape = Linear(weights.shape.o);
+ biases.data.resize(weights.shape.o, 0.0f);
+ UploadBias(biases);
+}
+
+template <DataType T> void ConvPowerVR::UploadBias(const InternalTensor<Linear, T> &bias)
+{
+ BufferDescriptor desc;
+ desc.element_type = conv_params_.weights_data_type;
+ desc.element_size = 4;
+ desc.memory_type =
+ conv_params_.weights_upload_type == ConvPowerVR::WeightsUploadType::CONSTANT_MEM
+ ? MemoryType::CONSTANT
+ : MemoryType::GLOBAL;
+ const int float_size = sizeof(float);
+ // TODO
+ // conv_params_.weights_data_type == DataType::FLOAT32 ? sizeof(float) : sizeof(half);
+ int aligned_channels = AlignByN(bias.shape.v, 4 * conv_params_.block_size.w);
+ desc.size = float_size * aligned_channels;
+ desc.data.resize(desc.size);
+ if (conv_params_.weights_data_type == DataType::FLOAT32)
+ {
+ float *gpu_data = reinterpret_cast<float *>(desc.data.data());
+ for (int i = 0; i < aligned_channels; ++i)
+ {
+ gpu_data[i] = i < bias.shape.v ? bias.data[i] : 0.0f;
+ }
+ }
+ // else
+ // {
+ // half *gpu_data = reinterpret_cast<half *>(desc.data.data());
+ // for (int i = 0; i < aligned_channels; ++i)
+ // {
+ // gpu_data[i] = i < bias.shape.v ? bias.data[i] : 0.0f;
+ // }
+ // }
+ args_.AddObject("biases", absl::make_unique<BufferDescriptor>(std::move(desc)));
+}
+
+template <DataType T> void ConvPowerVR::UploadWeights(const InternalTensor<OHWI, T> &weights)
+{
+ const int dst_slices = AlignByN(DivideRoundUp(weights.shape.o, 4), conv_params_.block_size.w);
+ const int src_slices = DivideRoundUp(weights.shape.i, 4);
+
+ const bool f32_weights = conv_params_.weights_data_type == DataType::FLOAT32;
+ const int float4_size = sizeof(float4);
+ // TODO
+ // f32_weights ? sizeof(float4) : sizeof(half4);
+
+ const int elements_count = weights.shape.h * weights.shape.w * src_slices * dst_slices * 4;
+
+ std::vector<uint8_t> data(float4_size * elements_count);
+
+ if (f32_weights)
+ {
+ float4 *ptr = reinterpret_cast<float4 *>(data.data());
+ if (conv_params_.AreWeightsBuffer())
+ {
+ RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.w,
+ absl::MakeSpan(ptr, elements_count));
+ }
+ else
+ {
+ RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.w,
+ absl::MakeSpan(ptr, elements_count));
+ }
+ }
+ // else
+ // {
+ // half4 *ptr = reinterpret_cast<half4 *>(data.data());
+ // if (conv_params_.AreWeightsBuffer())
+ // {
+ // RearrangeWeightsToOHWIOGroupI4O4(weights, conv_params_.block_size.w,
+ // absl::MakeSpan(ptr, elements_count));
+ // }
+ // else
+ // {
+ // RearrangeWeightsToI4HWIOOGroupO4(weights, conv_params_.block_size.w,
+ // absl::MakeSpan(ptr, elements_count));
+ // }
+ // }
+ if (conv_params_.AreWeightsBuffer())
+ {
+ BufferDescriptor desc;
+ desc.element_type = conv_params_.weights_data_type;
+ desc.element_size = 4;
+ desc.memory_type =
+ conv_params_.weights_upload_type == ConvPowerVR::WeightsUploadType::CONSTANT_MEM
+ ? MemoryType::CONSTANT
+ : MemoryType::GLOBAL;
+ desc.size = float4_size * elements_count;
+ desc.data = std::move(data);
+ args_.AddObject("weights", absl::make_unique<BufferDescriptor>(std::move(desc)));
+ }
+ else
+ {
+ const int texture_width = dst_slices;
+ const int texture_height = src_slices * weights.shape.h * weights.shape.w;
+ const int sub_size = float4_size * texture_width * texture_height;
+ for (int i = 0; i < 4; ++i)
+ {
+ Texture2DDescriptor desc;
+ desc.element_type = conv_params_.weights_data_type;
+ desc.size = int2(texture_width, texture_height);
+ desc.data.resize(sub_size);
+ std::memcpy(desc.data.data(), data.data() + sub_size * i, sub_size);
+ const std::string name = "weights" + std::to_string(i);
+ args_.AddObject(name, absl::make_unique<Texture2DDescriptor>(std::move(desc)));
+ }
+ }
+}
+
+template <DataType T> void ConvPowerVR::UploadWeights(const InternalTensor<OHWDI, T> &weights)
+{
+ const int block_size = conv_params_.block_size.w;
+ const int dst_slices = AlignByN(DivideRoundUp(weights.shape.o, 4), block_size);
+ const int src_slices = DivideRoundUp(weights.shape.i, 4);
+
+ const int elements_count =
+ weights.shape.d * weights.shape.h * weights.shape.w * src_slices * dst_slices * 4;
+ const bool f32_weights = definition_.precision == CalculationsPrecision::F32;
+
+ const int float4_size = f32_weights ? 16 : 8;
+
+ std::vector<uint8_t> data(float4_size * elements_count);
+
+ if (f32_weights)
+ {
+ float4 *ptr = reinterpret_cast<float4 *>(data.data());
+ if (conv_params_.AreWeightsBuffer())
+ {
+ RearrangeWeightsToODHWIOGroupI4O4(weights, conv_params_.block_size.w,
+ absl::MakeSpan(ptr, elements_count));
+ }
+ else
+ {
+ RearrangeWeightsToI4DHWIOOGroupO4(weights, conv_params_.block_size.w,
+ absl::MakeSpan(ptr, elements_count));
+ }
+ }
+ // else
+ // {
+ // half4 *ptr = reinterpret_cast<half4 *>(data.data());
+ // if (conv_params_.AreWeightsBuffer())
+ // {
+ // RearrangeWeightsToODHWIOGroupI4O4(weights, conv_params_.block_size.w,
+ // absl::MakeSpan(ptr, elements_count));
+ // }
+ // else
+ // {
+ // RearrangeWeightsToI4DHWIOOGroupO4(weights, conv_params_.block_size.w,
+ // absl::MakeSpan(ptr, elements_count));
+ // }
+ // }
+
+ if (conv_params_.AreWeightsBuffer())
+ {
+ BufferDescriptor desc;
+ desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.element_size = 4;
+ desc.size = float4_size * elements_count;
+ desc.data = std::move(data);
+ args_.AddObject("weights", absl::make_unique<BufferDescriptor>(std::move(desc)));
+ }
+ else
+ {
+ const int texture_width = dst_slices;
+ const int texture_height = src_slices * weights.shape.d * weights.shape.h * weights.shape.w;
+ int sub_size = float4_size * texture_width * texture_height;
+ for (int i = 0; i < 4; ++i)
+ {
+ Texture2DDescriptor desc;
+ desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.size = int2(texture_width, texture_height);
+ desc.data.resize(sub_size);
+ memcpy(desc.data.data(), data.data() + sub_size * i, sub_size);
+ const std::string name = "weights" + std::to_string(i);
+ args_.AddObject(name, absl::make_unique<Texture2DDescriptor>(std::move(desc)));
+ }
+ }
+}
+
+ConvPowerVR CreateConvPowerVR(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution2DAttributes &attr, const BHWC *dst_shape = nullptr);
+
+ConvPowerVR CreateConvPowerVR(const DeviceInfo &device_info, const OperationDef &definition,
+ const FullyConnectedAttributes &attr,
+ const BHWC *dst_shape = nullptr);
+
+ConvPowerVR CreateConvPowerVRDynamicWeights(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC &weights_shape,
+ const BHWC *dst_shape = nullptr);
+
+ConvPowerVR CreateConvPowerVRWino4x4To6x6(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const Convolution2DAttributes &attr,
+ const BHWC *dst_shape = nullptr);
+
+ConvPowerVR CreateConvPowerVR3D(const DeviceInfo &device_info, const OperationDef &definition,
+ const Convolution3DAttributes &attr,
+ const BHWDC *dst_shape = nullptr);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_POWERVR_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "open_cl/kernels/ConvWeightsConverter.h"
+
+#include <string>
+
+#include "open_cl/kernels/Util.h"
+#include "open_cl/kernels/WorkGroupPicking.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+ConverterToConvWeights::ConverterToConvWeights(const OperationDef &definition,
+ const ConvWeightsDescription &conv_weights_desc)
+ : GPUOperation(definition), conv_weights_desc_(conv_weights_desc)
+{
+ code_ = GetConverterToConvWeightsCode(definition_, conv_weights_desc_);
+}
+
+ConverterToConvWeights::ConverterToConvWeights(ConverterToConvWeights &&operation)
+ : GPUOperation(std::move(operation)), conv_weights_desc_(operation.conv_weights_desc_)
+{
+}
+
+ConverterToConvWeights &ConverterToConvWeights::operator=(ConverterToConvWeights &&operation)
+{
+ if (this != &operation)
+ {
+ conv_weights_desc_ = operation.conv_weights_desc_;
+ GPUOperation::operator=(std::move(operation));
+ }
+ return *this;
+}
+
+std::string ConverterToConvWeights::GetConverterToConvWeightsCode(
+ const OperationDef &op_def, const ConvWeightsDescription &conv_weights_desc)
+{
+ AddSrcTensor("src_tensor", op_def.src_tensors[0]);
+ AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
+ args_.AddFloat("mask_x");
+ args_.AddFloat("mask_y");
+ args_.AddFloat("mask_z");
+ args_.AddFloat("mask_w");
+
+ std::string c = GetCommonDefines(op_def.precision);
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ c += " int GROUP_SIZE = " + std::to_string(conv_weights_desc.output_group_size) + ";\n";
+ c += " int O = get_global_id(0) * 4;\n";
+ c += " int I = get_global_id(1);\n";
+ c += " int Z = get_global_id(2);\n";
+ c += " int W = Z % args.src_tensor.Width();\n";
+ c += " int H = Z / args.src_tensor.Width();\n";
+ c += " if (O >= args.src_tensor.Batch() || I >= args.src_tensor.Slices() || "
+ "H >= args.src_tensor.Height()) return;\n";
+ c += " FLT4 v0 = args.src_tensor.Read(W, H, I, O + 0);\n";
+ c += " FLT4 v1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
+ c += " FLT4 v2 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
+ c += " FLT4 v3 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
+ c += " if (O + 1 < args.src_tensor.Batch()) {\n";
+ c += " v1 = args.src_tensor.Read(W, H, I, O + 1);\n";
+ c += " }\n";
+ c += " if (O + 2 < args.src_tensor.Batch()) {\n";
+ c += " v2 = args.src_tensor.Read(W, H, I, O + 2);\n";
+ c += " }\n";
+ c += " if (O + 3 < args.src_tensor.Batch()) {\n";
+ c += " v3 = args.src_tensor.Read(W, H, I, O + 3);\n";
+ c += " }\n";
+ c += " if (I == args.src_tensor.Slices() - 1) {\n";
+ c += " FLT4 mask = (FLT4)(args.mask_x, args.mask_y, args.mask_z, "
+ "args.mask_w);\n";
+ c += " v0 *= mask;\n";
+ c += " v1 *= mask;\n";
+ c += " v2 *= mask;\n";
+ c += " v3 *= mask;\n";
+ c += " }\n";
+ c += " FLT4 r0 = (FLT4)(v0.x, v1.x, v2.x, v3.x);\n";
+ c += " FLT4 r1 = (FLT4)(v0.y, v1.y, v2.y, v3.y);\n";
+ c += " FLT4 r2 = (FLT4)(v0.z, v1.z, v2.z, v3.z);\n";
+ c += " FLT4 r3 = (FLT4)(v0.w, v1.w, v2.w, v3.w);\n";
+ c += " int d_index = O / (GROUP_SIZE * 4);\n";
+ c += " int k_index = (O % (GROUP_SIZE * 4)) / 4;\n";
+ c += " int dst_offset = (((d_index * args.src_tensor.Height() + H) * "
+ "args.src_tensor.Width() + W) * "
+ "args.src_tensor.Slices() + I) * GROUP_SIZE + "
+ "k_index;\n";
+ c += " int address0 = dst_offset * 4 + 0;\n";
+ c += " int address1 = dst_offset * 4 + 1;\n";
+ c += " int address2 = dst_offset * 4 + 2;\n";
+ c += " int address3 = dst_offset * 4 + 3;\n";
+ c += " args.dst_tensor.WriteLinear(r0, dst_offset * 4 + 0)\n;";
+ c += " args.dst_tensor.WriteLinear(r1, dst_offset * 4 + 1)\n;";
+ c += " args.dst_tensor.WriteLinear(r2, dst_offset * 4 + 2)\n;";
+ c += " args.dst_tensor.WriteLinear(r3, dst_offset * 4 + 3)\n;";
+ c += "}\n";
+ return c;
+}
+
+absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder *args)
+{
+ float4 mask = GetMaskForLastPlane(src_[0]->Channels());
+ RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
+ RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
+ RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
+ return args->SetFloat("mask_w", mask.w);
+}
+
+int3 ConverterToConvWeights::GetGridSize() const
+{
+ const int grid_x =
+ DivideRoundUp(AlignByN(src_[0]->Batch(), 4 * conv_weights_desc_.output_group_size), 4);
+ const int grid_y = src_[0]->Slices();
+ const int grid_z = src_[0]->Width() * src_[0]->Height();
+ return int3(grid_x, grid_y, grid_z);
+}
+
+ConverterToConvWeights CreateConverterToConvWeights(const OperationDef &definition,
+ const ConvWeightsDescription &conv_weights_desc)
+{
+ return ConverterToConvWeights(definition, conv_weights_desc);
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_WEIGHTS_CONVERTER_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_WEIGHTS_CONVERTER_H__
+
+#include "open_cl/ClCommandQueue.h"
+#include "open_cl/ClKernel.h"
+#include "open_cl/kernels/ConvCommon.h"
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/Status.h"
+#include "open_cl/Types.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class ConverterToConvWeights : public GPUOperation
+{
+public:
+ ConverterToConvWeights(const OperationDef &definition,
+ const ConvWeightsDescription &conv_weights_desc);
+ absl::Status BindArguments(ArgumentsBinder *args) override;
+ int3 GetGridSize() const override;
+
+ // Move only
+ ConverterToConvWeights(ConverterToConvWeights &&operation);
+ ConverterToConvWeights &operator=(ConverterToConvWeights &&operation);
+ ConverterToConvWeights(const ConverterToConvWeights &) = delete;
+ ConverterToConvWeights &operator=(const ConverterToConvWeights &) = delete;
+
+private:
+ std::string GetConverterToConvWeightsCode(const OperationDef &op_def,
+ const ConvWeightsDescription &conv_weights_desc);
+
+ ConvWeightsDescription conv_weights_desc_;
+};
+
+// We expect a src BHWC tensor and assume that B is O, H is H, W is W, C is I;
+// as dst we expect a Tensor with storage type BUFFER and
+// dst.b * dst.h * dst.w * dst.c = AlignByN(src.b, 4) * src.h * src.w *
+// AlignByN(src.c, 4)
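+//
+// For example (hypothetical shapes): a src of BHWC(b=10, h=3, w=3, c=6),
+// read as O=10, H=3, W=3, I=6, needs a dst BUFFER holding
+// AlignByN(10, 4) * 3 * 3 * AlignByN(6, 4) = 12 * 3 * 3 * 8 = 864 elements.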
+ConverterToConvWeights
+CreateConverterToConvWeights(const OperationDef &definition,
+ const ConvWeightsDescription &conv_weights_desc);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONV_WEIGHTS_CONVERTER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Converter.h"
+
+#include <algorithm>
+#include <array>
+#include <string>
+
+#include "open_cl/Arguments.h"
+#include "open_cl/ClCommandQueue.h"
+#include "open_cl/ClErrors.h"
+#include "open_cl/kernels/Util.h"
+#include "open_cl/Precision.h"
+#include "open_cl/InternalTensor.h"
+#include "open_cl/TensorType.h"
+#include "open_cl/TensorTypeUtil.h"
+#include "open_cl/Util.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+class OpenClConverterImpl : public TensorObjectConverter
+{
+public:
+ virtual absl::Status Init(const TensorObjectDef &input_def, const TensorObjectDef &output_def,
+ Environment *environment) = 0;
+
+protected:
+ absl::Status DispatchKernel(cl_mem buffer_mem, Tensor *tensor)
+ {
+ kernel_.ResetBindingCounter();
+ RETURN_IF_ERROR(kernel_.SetMemoryAuto(buffer_mem));
+ RETURN_IF_ERROR(args_.SetObjectRef("tensor", tensor));
+ RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter()));
+ const int3 grid = int3(tensor->Width() * tensor->Batch(), tensor->Height(), tensor->Slices());
+ const int3 work_group_size = {16, 8, 1};
+ const int3 work_groups_count = GetWorkGroupsCount(grid, work_group_size);
+ return queue_->Dispatch(kernel_, work_groups_count, work_group_size);
+ }
+
+ Arguments args_;
+ BHWC shape_;
+ CLKernel kernel_;
+ TensorDescriptor tensor_descriptor_;
+ CLCommandQueue *queue_ = nullptr;
+ const CLContext *context_ = nullptr;
+};
+
+bool IsSupportedDataType(DataType type)
+{
+ return type == DataType::FLOAT16 || type == DataType::FLOAT32;
+}
+
+bool IsBHWCOpenCLBuffer(const ObjectDef &def)
+{
+ return IsSupportedDataType(def.data_type) && def.object_type == ObjectType::OPENCL_BUFFER &&
+ def.data_layout == DataLayout::BHWC;
+}
+
+bool IsOpenCLTensor(const ObjectDef &def)
+{
+ const bool is_buffer_tensor =
+ def.object_type == ObjectType::OPENCL_BUFFER && def.data_layout == DataLayout::DHWC4;
+ const bool is_image2d_tensor =
+ def.object_type == ObjectType::OPENCL_TEXTURE && def.data_layout == DataLayout::HDWC4;
+ const bool is_image2d_array_tensor =
+ def.object_type == ObjectType::OPENCL_TEXTURE && def.data_layout == DataLayout::DHWC4;
+ const bool is_single_image_tensor =
+ def.object_type == ObjectType::OPENCL_TEXTURE && def.data_layout == DataLayout::BHWC;
+ return IsSupportedDataType(def.data_type) && (is_buffer_tensor || is_image2d_tensor ||
+ is_image2d_array_tensor || is_single_image_tensor);
+}
+
+absl::Status GetOpenCLMemory(const TensorObject &obj, cl_mem *memory)
+{
+ auto texture = absl::get_if<OpenClTexture>(&obj);
+ auto buffer = absl::get_if<OpenClBuffer>(&obj);
+ if (texture && texture->memobj)
+ {
+ *memory = texture->memobj;
+ }
+ else if (buffer && buffer->memobj)
+ {
+ *memory = buffer->memobj;
+ }
+ else
+ {
+ return absl::InvalidArgumentError("Missing OpenCL object.");
+ }
+ return absl::OkStatus();
+}
+
+// Implements conversion from OpenCL tensor to another OpenCL tensor.
+class TensorToTensorConverter : public OpenClConverterImpl
+{
+public:
+ static bool IsSupported(const ObjectDef &input, const ObjectDef &output)
+ {
+ return IsOpenCLTensor(input) && IsOpenCLTensor(output);
+ }
+
+ absl::Status Init(const TensorObjectDef &input_def, const TensorObjectDef &output_def,
+ Environment *environment) final
+ {
+ src_tensor_descriptor_.layout = Layout::BHWC;
+ src_tensor_descriptor_.storage_type =
+ ToTensorStorageType(input_def.object_def.object_type, input_def.object_def.data_layout);
+ src_tensor_descriptor_.data_type = input_def.object_def.data_type;
+ args_.AddObjectRef("src_tensor", AccessType::READ,
+ absl::make_unique<TensorDescriptor>(src_tensor_descriptor_));
+
+ dst_tensor_descriptor_.layout = Layout::BHWC;
+ dst_tensor_descriptor_.storage_type =
+ ToTensorStorageType(output_def.object_def.object_type, output_def.object_def.data_layout);
+ dst_tensor_descriptor_.data_type = output_def.object_def.data_type;
+ args_.AddObjectRef("dst_tensor", AccessType::WRITE,
+ absl::make_unique<TensorDescriptor>(dst_tensor_descriptor_));
+
+ const bool need_fp16_support = input_def.object_def.data_type == DataType::FLOAT16 ||
+ output_def.object_def.data_type == DataType::FLOAT16;
+ const std::string out_data_type = ToCLDataType(output_def.object_def.data_type);
+ std::string shader_src;
+ if (need_fp16_support)
+ {
+ shader_src += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
+ }
+ shader_src +=
+ R"(__kernel void tensor_to_tensor($0) {
+ int linear_id = get_global_id(0);
+ int x = linear_id / args.dst_tensor.Batch();
+ int b = linear_id % args.dst_tensor.Batch();
+ int y = get_global_id(1);
+ int d = get_global_id(2);
+ if (x >= args.dst_tensor.Width() || y >= args.dst_tensor.Height() || d >= args.dst_tensor.Slices()) return;
+)";
+ shader_src +=
+ " " + out_data_type + "4 input = args.src_tensor.Read<" + out_data_type + ">(x, y, d, b);\n";
+ shader_src += " args.dst_tensor.Write(input, x, y, d, b);\n}";
+ queue_ = environment->queue();
+ context_ = &environment->context();
+ shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h, input_def.dimensions.w,
+ input_def.dimensions.c);
+ RETURN_IF_ERROR(args_.TransformToCLCode(environment->device().info_, {}, &shader_src));
+ return environment->program_cache()->GetOrCreateCLKernel(
+ shader_src, "tensor_to_tensor", environment->context(), environment->device(), &kernel_);
+ }
+
+ absl::Status Convert(const TensorObject &input_obj, const TensorObject &output_obj) override
+ {
+ cl_mem in_memory = nullptr;
+ RETURN_IF_ERROR(GetOpenCLMemory(input_obj, &in_memory));
+ cl_mem out_memory = nullptr;
+ RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory));
+
+ Tensor src_tensor;
+ RETURN_IF_ERROR(
+ CreateSharedTensor(*context_, in_memory, shape_, src_tensor_descriptor_, &src_tensor));
+ Tensor dst_tensor;
+ RETURN_IF_ERROR(
+ CreateSharedTensor(*context_, out_memory, shape_, dst_tensor_descriptor_, &dst_tensor));
+
+ RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", &src_tensor));
+ RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", &dst_tensor));
+
+ RETURN_IF_ERROR(args_.Bind(kernel_.kernel()));
+ const int3 grid =
+ int3(dst_tensor.Width() * dst_tensor.Batch(), dst_tensor.Height(), dst_tensor.Slices());
+ const int3 work_group_size = {16, 8, 1};
+ const int3 work_groups_count = GetWorkGroupsCount(grid, work_group_size);
+ return queue_->Dispatch(kernel_, work_groups_count, work_group_size);
+ }
+
+private:
+ TensorDescriptor src_tensor_descriptor_;
+ TensorDescriptor dst_tensor_descriptor_;
+};
+
+// Implements conversion from OpenCL-specific tensor layout to BHWC OpenCL
+// buffer.
+class TensorToBHWCBufferConverter : public OpenClConverterImpl
+{
+public:
+ static bool IsSupported(const ObjectDef &input, const ObjectDef &output)
+ {
+ return IsOpenCLTensor(input) && IsBHWCOpenCLBuffer(output);
+ }
+
+ absl::Status Init(const TensorObjectDef &input_def, const TensorObjectDef &output_def,
+ Environment *environment) final
+ {
+ TensorStorageType src_tensor_type =
+ ToTensorStorageType(input_def.object_def.object_type, input_def.object_def.data_layout);
+ tensor_descriptor_.layout = Layout::BHWC;
+ tensor_descriptor_.storage_type = src_tensor_type;
+ tensor_descriptor_.data_type = input_def.object_def.data_type;
+ args_.AddObjectRef("tensor", AccessType::READ,
+ absl::make_unique<TensorDescriptor>(tensor_descriptor_));
+
+ const bool need_fp16_support = input_def.object_def.data_type == DataType::FLOAT16 ||
+ output_def.object_def.data_type == DataType::FLOAT16;
+ std::string shader_src;
+ if (need_fp16_support)
+ {
+ shader_src += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
+ }
+ const std::string out_data_type = ToCLDataType(output_def.object_def.data_type);
+ shader_src += "__kernel void tensor_to_bhwc(";
+ shader_src += "__global " + out_data_type + "* dst, $0) {\n";
+ shader_src += R"( int linear_id = get_global_id(0);
+ int x = linear_id / args.tensor.Batch();
+ int b = linear_id % args.tensor.Batch();
+ int y = get_global_id(1);
+ int d = get_global_id(2);
+ if (x >= args.tensor.Width() || y >= args.tensor.Height() || d >= args.tensor.Slices()) return;
+)";
+ shader_src +=
+ " " + out_data_type + "4 input = args.tensor.Read<" + out_data_type + ">(x, y, d, b);\n";
+ shader_src += R"( int c = d * 4;
+ int index = ((b * args.tensor.Height() + y) * args.tensor.Width() + x) * args.tensor.Channels() + c;
+
+ dst[index] = input.x;
+ if (c + 1 < args.tensor.Channels()) {
+ dst[index + 1] = input.y;
+ }
+ if (c + 2 < args.tensor.Channels()) {
+ dst[index + 2] = input.z;
+ }
+ if (c + 3 < args.tensor.Channels()) {
+ dst[index + 3] = input.w;
+ }
+})";
+ queue_ = environment->queue();
+ context_ = &environment->context();
+ shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h, input_def.dimensions.w,
+ input_def.dimensions.c);
+ RETURN_IF_ERROR(args_.TransformToCLCode(environment->device().info_, {}, &shader_src));
+ return environment->program_cache()->GetOrCreateCLKernel(
+ shader_src, "tensor_to_bhwc", environment->context(), environment->device(), &kernel_);
+ }
+
+ absl::Status Convert(const TensorObject &input_obj, const TensorObject &output_obj) override
+ {
+ auto output = absl::get_if<OpenClBuffer>(&output_obj);
+ if (!output || !output->memobj)
+ {
+ return absl::InvalidArgumentError("Missing output in tensor_to_bhwc converter");
+ }
+
+ cl_mem in_memory = nullptr;
+ RETURN_IF_ERROR(GetOpenCLMemory(input_obj, &in_memory));
+ Tensor tensor;
+ RETURN_IF_ERROR(CreateSharedTensor(*context_, in_memory, shape_, tensor_descriptor_, &tensor));
+ return DispatchKernel(output->memobj, &tensor);
+ }
+};
+
+// Implements conversion from BHWC OpenCL buffer to OpenCL-specific tensor
+// layout.
+class BHWCBufferToTensorConverter : public OpenClConverterImpl
+{
+public:
+ static bool IsSupported(const ObjectDef &input, const ObjectDef &output)
+ {
+ return IsBHWCOpenCLBuffer(input) && IsOpenCLTensor(output);
+ }
+
+ std::pair<std::string, std::string> GetFromBhwcKernel(const TensorObjectDef &input_def,
+ const TensorObjectDef &) const
+ {
+ return std::make_pair("__global " + ToCLDataType(input_def.object_def.data_type) + "* src",
+ R"(int c = d * 4;
+ int index = ((b * args.tensor.Height() + y) * args.tensor.Width() + x) * args.tensor.Channels() + c;
+ result.x = src[index];
+ result.y = c + 1 < args.tensor.Channels() ? src[index + 1] : 1;
+ result.z = c + 2 < args.tensor.Channels() ? src[index + 2] : 2;
+ result.w = c + 3 < args.tensor.Channels() ? src[index + 3] : 3;
+)");
+ }
+
+ absl::Status Init(const TensorObjectDef &input_def, const TensorObjectDef &output_def,
+ Environment *environment) final
+ {
+ auto params_kernel = GetFromBhwcKernel(input_def, output_def);
+
+ TensorStorageType dst_tensor_type =
+ ToTensorStorageType(output_def.object_def.object_type, output_def.object_def.data_layout);
+ tensor_descriptor_.layout = Layout::BHWC;
+ tensor_descriptor_.storage_type = dst_tensor_type;
+ tensor_descriptor_.data_type = output_def.object_def.data_type;
+ args_.AddObjectRef("tensor", AccessType::WRITE,
+ absl::make_unique<TensorDescriptor>(tensor_descriptor_));
+
+ const bool need_fp16_support = input_def.object_def.data_type == DataType::FLOAT16 ||
+ output_def.object_def.data_type == DataType::FLOAT16;
+ std::string shader_src;
+ if (need_fp16_support)
+ {
+ shader_src += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
+ }
+ const std::string in_data_type = ToCLDataType(input_def.object_def.data_type);
+ const std::string out_data_type = ToCLDataType(output_def.object_def.data_type);
+ shader_src += "__kernel void bhwc_to_tensor(";
+ shader_src += "__global " + in_data_type + "* src, $0) {\n";
+
+ shader_src += R"( int linear_id = get_global_id(0);
+ int x = linear_id / args.tensor.Batch();
+ int b = linear_id % args.tensor.Batch();
+ int y = get_global_id(1);
+ int d = get_global_id(2);
+
+ if (x >= args.tensor.Width() || y >= args.tensor.Height() || d >= args.tensor.Slices()) return;
+)";
+ shader_src += " " + out_data_type + "4 result;\n";
+ shader_src += R"( int c = d * 4;
+ int index = ((b * args.tensor.Height() + y) * args.tensor.Width() + x) * args.tensor.Channels() + c;
+ result.x = src[index];
+ result.y = c + 1 < args.tensor.Channels() ? src[index + 1] : 1;
+ result.z = c + 2 < args.tensor.Channels() ? src[index + 2] : 2;
+ result.w = c + 3 < args.tensor.Channels() ? src[index + 3] : 3;
+)";
+ shader_src += " args.tensor.Write(result, x, y, d, b);\n}";
+ queue_ = environment->queue();
+ context_ = &environment->context();
+ shape_ = BHWC(output_def.dimensions.b, output_def.dimensions.h, output_def.dimensions.w,
+ output_def.dimensions.c);
+ RETURN_IF_ERROR(args_.TransformToCLCode(environment->device().info_, {}, &shader_src));
+ return environment->program_cache()->GetOrCreateCLKernel(
+ shader_src, "bhwc_to_tensor", environment->context(), environment->device(), &kernel_);
+ }
+
+ absl::Status Convert(const TensorObject &input_obj, const TensorObject &output_obj) override
+ {
+ auto input = absl::get_if<OpenClBuffer>(&input_obj);
+ if (!input || !input->memobj)
+ {
+ return absl::InvalidArgumentError("Missing input in bhwc_to_tensor converter");
+ }
+ cl_mem out_memory = nullptr;
+ RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory));
+ Tensor tensor;
+ RETURN_IF_ERROR(CreateSharedTensor(*context_, out_memory, shape_, tensor_descriptor_, &tensor));
+ return DispatchKernel(input->memobj, &tensor);
+ }
+};
+
+std::array<size_t, 3> CalculateTextureRegion(const TensorObjectDef &def)
+{
+ const auto &dims = def.dimensions;
+ std::array<size_t, 3> region = {0, 0, 1};
+ switch (ToTensorStorageType(def.object_def.object_type, def.object_def.data_layout))
+ {
+ case TensorStorageType::SINGLE_TEXTURE_2D:
+ region[0] = static_cast<size_t>(dims.w * dims.b);
+ region[1] = static_cast<size_t>(dims.h);
+ break;
+ case TensorStorageType::TEXTURE_2D:
+ region[0] = static_cast<size_t>(dims.w * dims.b);
+ region[1] = static_cast<size_t>(dims.h * dims.d());
+ break;
+ case TensorStorageType::TEXTURE_ARRAY:
+ region[0] = static_cast<size_t>(dims.w * dims.b);
+ region[1] = static_cast<size_t>(dims.h);
+ region[2] = static_cast<size_t>(dims.d());
+ break;
+ default:
+ break;
+ }
+ return region;
+}
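+// Illustrative example (added comment; the numbers are assumptions, not from the
+// original code): if dims.d() is the number of 4-channel slices, i.e. ceil(c / 4),
+// then a tensor with b=1, h=4, w=8, c=10 has d() = 3 and the copy regions are
+//   SINGLE_TEXTURE_2D: {w*b, h, 1}     = {8,  4, 1}
+//   TEXTURE_2D:        {w*b, h*d(), 1} = {8, 12, 1}
+//   TEXTURE_ARRAY:     {w*b, h, d()}   = {8,  4, 3}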
+
+bool IsOpenClTextureOrBuffer(ObjectType type)
+{
+ return type == ObjectType::OPENCL_BUFFER || type == ObjectType::OPENCL_TEXTURE;
+}
+
+// Copies data from one object of the same type and layout to another object.
+class TrivialCopier : public OpenClConverterImpl
+{
+public:
+ static bool IsSupported(const ObjectDef &input, const ObjectDef &output)
+ {
+ return IsOpenClTextureOrBuffer(input.object_type) && input.data_type == output.data_type &&
+ input.object_type == output.object_type && input.data_layout == output.data_layout;
+ }
+
+ absl::Status Init(const TensorObjectDef &input_def, const TensorObjectDef &output_def,
+ Environment *environment) final
+ {
+ shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h, input_def.dimensions.w,
+ input_def.dimensions.c);
+ data_type_ = input_def.object_def.data_type;
+ queue_ = environment->queue();
+ region_ = CalculateTextureRegion(output_def);
+ return absl::OkStatus();
+ }
+
+ absl::Status Convert(const TensorObject &input_obj, const TensorObject &output_obj) override
+ {
+ auto texture_input = absl::get_if<OpenClTexture>(&input_obj);
+ auto texture_output = absl::get_if<OpenClTexture>(&output_obj);
+ if (texture_input && texture_output)
+ {
+ return Copy(*texture_input, *texture_output);
+ }
+ auto buffer_input = absl::get_if<OpenClBuffer>(&input_obj);
+ auto buffer_output = absl::get_if<OpenClBuffer>(&output_obj);
+ if (buffer_input && buffer_output)
+ {
+ return Copy(*buffer_input, *buffer_output);
+ }
+ return absl::InternalError("Unexpected object");
+ }
+
+ absl::Status Copy(const OpenClBuffer &input, const OpenClBuffer &output)
+ {
+ if (input.memobj == output.memobj)
+ {
+ return absl::OkStatus();
+ }
+ return GetOpenCLError(clEnqueueCopyBuffer(queue_->queue(), input.memobj, output.memobj, 0, 0,
+ SizeOf(data_type_) * shape_.w * shape_.h *
+ AlignByN(shape_.c, 4) * shape_.b,
+ 0, nullptr, nullptr));
+ }
+
+ absl::Status Copy(const OpenClTexture &input, const OpenClTexture &output)
+ {
+ if (input.memobj == output.memobj)
+ {
+ return absl::OkStatus();
+ }
+ size_t origin[3] = {0, 0, 0};
+ return GetOpenCLError(clEnqueueCopyImage(queue_->queue(), input.memobj, output.memobj, origin,
+ origin, region_.data(), 0, nullptr, nullptr));
+ }
+
+private:
+ DataType data_type_ = DataType::UNKNOWN;
+ std::array<size_t, 3> region_;
+};
+
+// Copies data from/to CPU into a tensor.
+class CpuCopier : public OpenClConverterImpl
+{
+public:
+ static bool IsSupported(const ObjectDef &input, const ObjectDef &output)
+ {
+ return input.data_type == output.data_type && input.data_layout == output.data_layout &&
+ ((input.object_type == ObjectType::CPU_MEMORY &&
+ IsOpenClTextureOrBuffer(output.object_type)) ||
+ (output.object_type == ObjectType::CPU_MEMORY &&
+ IsOpenClTextureOrBuffer(input.object_type)));
+ }
+
+ absl::Status Init(const TensorObjectDef &input_def, const TensorObjectDef &output_def,
+ Environment *environment) final
+ {
+ region_ = CalculateTextureRegion(
+ input_def.object_def.object_type == ObjectType::CPU_MEMORY ? output_def : input_def);
+ queue_ = environment->queue();
+ return absl::OkStatus();
+ }
+
+ absl::Status Convert(const TensorObject &input_obj, const TensorObject &output_obj) override
+ {
+ auto cpu_input = absl::get_if<CpuMemory>(&input_obj);
+ auto cpu_output = absl::get_if<CpuMemory>(&output_obj);
+
+ if (cpu_input)
+ {
+ auto texture_output = absl::get_if<OpenClTexture>(&output_obj);
+ if (texture_output)
+ {
+ return queue_->EnqueueWriteImage(texture_output->memobj,
+ int3(region_[0], region_[1], region_[2]), cpu_input->data);
+ }
+ auto buffer_output = absl::get_if<OpenClBuffer>(&output_obj);
+ if (buffer_output)
+ {
+ return queue_->EnqueueWriteBuffer(buffer_output->memobj, cpu_input->size_bytes,
+ cpu_input->data);
+ }
+ }
+ else if (cpu_output)
+ {
+ auto texture_input = absl::get_if<OpenClTexture>(&input_obj);
+ if (texture_input)
+ {
+ return queue_->EnqueueReadImage(texture_input->memobj,
+ int3(region_[0], region_[1], region_[2]), cpu_output->data);
+ }
+ auto buffer_input = absl::get_if<OpenClBuffer>(&input_obj);
+ if (buffer_input)
+ {
+ return queue_->EnqueueReadBuffer(buffer_input->memobj, cpu_output->size_bytes,
+ cpu_output->data);
+ }
+ }
+ return absl::InternalError("Unexpected object");
+ }
+
+private:
+ std::array<size_t, 3> region_;
+};
+
+class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder
+{
+public:
+ explicit OpenClTensorConverterBuilder(Environment *environment) : environment_(environment) {}
+
+ bool IsSupported(const TensorObjectDef &input, const TensorObjectDef &output) const final
+ {
+ const auto &input_def = input.object_def;
+ const auto &output_def = output.object_def;
+ return input.dimensions == output.dimensions &&
+ (TrivialCopier::IsSupported(input_def, output_def) ||
+ TensorToTensorConverter::IsSupported(input_def, output_def) ||
+ CpuCopier::IsSupported(input_def, output_def) ||
+ TensorToBHWCBufferConverter::IsSupported(input_def, output_def) ||
+ BHWCBufferToTensorConverter::IsSupported(input_def, output_def));
+ }
+
+ absl::Status MakeConverter(const TensorObjectDef &input, const TensorObjectDef &output,
+ std::unique_ptr<TensorObjectConverter> *converter) final
+ {
+ std::unique_ptr<OpenClConverterImpl> impl;
+ const auto &input_def = input.object_def;
+ const auto &output_def = output.object_def;
+ if (TrivialCopier::IsSupported(input_def, output_def))
+ {
+ impl = absl::make_unique<TrivialCopier>();
+ }
+ else if (TensorToTensorConverter::IsSupported(input_def, output_def))
+ {
+ impl = absl::make_unique<TensorToTensorConverter>();
+ }
+ else if (CpuCopier::IsSupported(input_def, output_def))
+ {
+ impl = absl::make_unique<CpuCopier>();
+ }
+ else if (TensorToBHWCBufferConverter::IsSupported(input_def, output_def))
+ {
+ impl = absl::make_unique<TensorToBHWCBufferConverter>();
+ }
+ else if (BHWCBufferToTensorConverter::IsSupported(input_def, output_def))
+ {
+ impl = absl::make_unique<BHWCBufferToTensorConverter>();
+ }
+ else
+ {
+ return absl::UnimplementedError("Unsupported conversion");
+ }
+ RETURN_IF_ERROR(impl->Init(input, output, environment_));
+ *converter = std::move(impl);
+ return absl::OkStatus();
+ }
+
+private:
+ Environment *environment_;
+};
+
+} // namespace
+
+std::unique_ptr<TensorObjectConverterBuilder> NewConverterBuilder(Environment *environment)
+{
+ return absl::make_unique<OpenClTensorConverterBuilder>(environment);
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONVERTER_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONVERTER_H__
+
+#include <memory>
+
+#include "open_cl/Environment.h"
+#include "open_cl/Spi.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+// Supports conversions from BHWC to internal OpenCL tensor representation and
+// back. Also supports F16/F32.
+std::unique_ptr<TensorObjectConverterBuilder> NewConverterBuilder(Environment *environment);
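+// Usage sketch (illustrative only; the environment / *_def / *_obj values below
+// are assumptions and not defined in this header):
+//
+//   std::unique_ptr<TensorObjectConverterBuilder> builder =
+//     NewConverterBuilder(&environment);
+//   if (builder->IsSupported(input_def, output_def))
+//   {
+//     std::unique_ptr<TensorObjectConverter> converter;
+//     // Compiles the conversion kernel, or picks a plain copy when possible.
+//     auto status = builder->MakeConverter(input_def, output_def, &converter);
+//     if (status.ok())
+//       status = converter->Convert(input_obj, output_obj);
+//   }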
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_CONVERTER_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DepthwiseConv.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "open_cl/ClDevice.h"
+#include "open_cl/kernels/Util.h"
+#include "open_cl/kernels/WorkGroupPicking.h"
+#include "open_cl/LinearStorage.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+bool IsSpecializedCase(int channel_multiplier)
+{
+ return channel_multiplier == 1 || channel_multiplier == 2 || channel_multiplier == 4;
+}
+
+std::string GetSrcValue(int channel_multiplier, const std::string &coords)
+{
+ std::string c;
+ if (channel_multiplier == 1)
+ {
+ c += " FLT4 src_final = args.src_tensor.Read(" + coords + ", S);\n";
+ }
+ else if (channel_multiplier == 2)
+ {
+ c += " int s_layer = S / 2;\n";
+ c += " FLT4 src = args.src_tensor.Read(" + coords + ", s_layer);\n";
+ c += " FLT2 t0 = S % 2 == 0 ? src.xy : src.zw;\n";
+ c += " FLT4 src_final = (FLT4)(t0.x, t0.x, t0.y, t0.y);\n";
+ }
+ else if (channel_multiplier == 4)
+ {
+ c += " int s_layer = S / 4;\n";
+ c += " FLT4 src = args.src_tensor.Read(" + coords + ", s_layer);\n";
+ c += " FLT t0 = src.x;\n";
+ c += " int reminder = S % 4;\n";
+ c += " if (reminder == 1) t0 = src.y;\n";
+ c += " if (reminder == 2) t0 = src.z;\n";
+ c += " if (reminder == 3) t0 = src.w;\n";
+ c += " FLT4 src_final = (FLT4)(t0, t0, t0, t0);\n";
+ }
+ else
+ {
+ c += " int s_layer = S / args.ch_multiplier;\n";
+ c += " FLT4 src = args.src_tensor.Read(" + coords + ", s_layer);\n";
+ c += " int s_offset = (S % args.ch_multiplier) * 4;\n";
+ c += " FLT4 src_final;\n";
+ c += " FLT temp_arr[4] = {src.x, src.y, src.z, src.w};\n";
+ c += " src_final.x = temp_arr[(s_offset + 0) / args.ch_multiplier];\n";
+ c += " src_final.y = temp_arr[(s_offset + 1) / args.ch_multiplier];\n";
+ c += " src_final.z = temp_arr[(s_offset + 2) / args.ch_multiplier];\n";
+ c += " src_final.w = temp_arr[(s_offset + 3) / args.ch_multiplier];\n";
+ }
+
+ return c;
+}
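+// Worked mapping (added comment): with channel_multiplier == 2, destination
+// slice S reads source slice S / 2 and duplicates each of the two selected
+// source channels (x,x,y,y); with channel_multiplier == 4 it reads source
+// slice S / 4 and broadcasts a single source channel to all four lanes. The
+// generic branch performs the same per-lane lookup through
+// temp_arr[(s_offset + lane) / ch_multiplier].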
+
+std::string GenerateDepthwiseConvolutionCode(const OperationDef &op_def, bool stride_correction,
+ int channel_multiplier, bool weights_are_buffer,
+ bool dynamic_weights, GPUOperation *op)
+{
+ auto src_desc = op_def.src_tensors[0];
+ src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
+ if (op_def.IsBatchSupported())
+ {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op->AddSrcTensor("src_tensor", src_desc);
+ if (dynamic_weights)
+ {
+ op->AddSrcTensor("weights", op_def.src_tensors[1]);
+ }
+
+ auto dst_desc = op_def.dst_tensors[0];
+ if (op_def.IsBatchSupported())
+ {
+ dst_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op->AddDstTensor("dst_tensor", dst_desc);
+
+ const auto src_tensor_type = op_def.src_tensors[0].storage_type;
+
+ std::string c = GetCommonDefines(op_def.precision);
+
+ const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
+ src_tensor_type == TensorStorageType::IMAGE_BUFFER;
+
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ c += " int X = get_global_id(0);\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " int linear_id_1 = get_global_id(1);\n";
+ c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n";
+ c += " int Z = linear_id_1 % args.dst_tensor.Depth();\n";
+ }
+ else
+ {
+ c += " int Y = get_global_id(1);\n";
+ }
+ c += " int S = get_global_id(2);\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+ "S >= args.dst_tensor.Slices()) { \n";
+ c += " return; \n";
+ c += " } \n";
+ c += " ACCUM_FLT4 r = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
+ if (stride_correction)
+ {
+ c += " int x_offseted = " +
+ GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x", "args.padding_x") +
+ ";\n";
+ }
+ else
+ {
+ if (op_def.IsBatchSupported())
+ {
+ c += " int x_offseted = X * args.stride_x + args.padding_x * "
+ "args.src_tensor.Batch();\n";
+ }
+ else
+ {
+ c += " int x_offseted = X * args.stride_x + args.padding_x;\n";
+ }
+ }
+ c += " int y_offseted = Y * args.stride_y + args.padding_y;\n";
+ if (!dynamic_weights)
+ {
+ std::string weights_offset = "args.kernel_size_x * args.kernel_size_y";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " int z_offseted = Z * args.stride_z + args.padding_z;\n";
+ weights_offset += " * args.kernel_size_z";
+ }
+ if (weights_are_buffer)
+ {
+ c += " int fx_c = S * " + weights_offset + ";\n";
+ }
+ else
+ {
+ c += " int fx_c = 0;\n";
+ }
+ }
+ std::string kernel_size_x = dynamic_weights ? "args.weights.Width()" : "args.kernel_size_x";
+ std::string kernel_size_y = dynamic_weights ? "args.weights.Height()" : "args.kernel_size_y";
+ std::string kernel_size_z = dynamic_weights ? "args.weights.Depth()" : "args.kernel_size_z";
+
+ std::string flat_coords = "x_c, y_c";
+ if (manual_clamp)
+ {
+ std::string check = "!outside_x && !outside_y";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ check += " && !outside_z";
+ flat_coords += ", z_c";
+ c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n";
+ c += " int z_c = z_offseted + kz * args.dilation_z;\n";
+ c += " bool outside_z = z_c < 0 || z_c >= args.src_tensor.Depth();\n";
+ }
+ c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n";
+ c += " int y_c = y_offseted + ky * args.dilation_y;\n";
+ c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n";
+ c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n";
+ const std::string dilation_x =
+ op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()" : "args.dilation_x";
+ c += " int x_c = x_offseted + kx * " + dilation_x + ";\n";
+ c += " bool outside_x = x_c < 0 || x_c >= args.src_tensor.Width();\n";
+ c += " if (" + check + ") {\n";
+ if (dynamic_weights)
+ {
+ c += " FLT4 f = args.weights.Read(kx, ky, S);\n";
+ }
+ else
+ {
+ if (weights_are_buffer)
+ {
+ c += " FLT4 f = args.weights.Read(fx_c);\n";
+ }
+ else
+ {
+ c += " FLT4 f = args.weights.Read(fx_c, S);\n";
+ }
+ }
+ c += GetSrcValue(channel_multiplier, flat_coords);
+ c += " r += TO_ACCUM_TYPE(src_final * f);\n";
+ c += " };\n";
+ if (!dynamic_weights)
+ {
+ c += " fx_c++;\n";
+ }
+ c += " }\n";
+ c += " }\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " }\n";
+ }
+ }
+ else
+ { // Texture types with ZERO clamping
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ flat_coords += ", z_c";
+ c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n";
+ c += " int z_c = z_offseted + kz * args.dilation_z;\n";
+ if (src_tensor_type != TensorStorageType::TEXTURE_3D)
+ { // Only TEXTURE_3D supports clamping
+ // in DEPTH dimension
+ c += " if (z_c < 0 || z_c >= args.src_tensor.Depth()) {\n";
+ c += " fx_c += args.kernel_size_y * args.kernel_size_x;\n";
+ c += " continue;\n";
+ c += " }\n";
+ }
+ }
+ c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n";
+ c += " int y_c = y_offseted + ky * args.dilation_y;\n";
+ c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n";
+ const std::string dilation_x =
+ op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()" : "args.dilation_x";
+ c += " int x_c = x_offseted + kx * " + dilation_x + ";\n";
+ c += GetSrcValue(channel_multiplier, flat_coords);
+ if (dynamic_weights)
+ {
+ c += " FLT4 f = args.weights.Read(kx, ky, S);\n";
+ }
+ else
+ {
+ if (weights_are_buffer)
+ {
+ c += " FLT4 f = args.weights.Read(fx_c);\n";
+ }
+ else
+ {
+ c += " FLT4 f = args.weights.Read(fx_c, S);\n";
+ }
+ c += " fx_c++;\n";
+ }
+ c += " r += TO_ACCUM_TYPE(src_final * f);\n";
+ c += " }\n";
+ c += " }\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " }\n";
+ }
+ }
+ c += " FLT4 res0 = TO_FLT4(r) + args.biases.Read(S);\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " args.dst_tensor.Write(res0, X, Y, Z, S);\n";
+ }
+ else
+ {
+ c += " args.dst_tensor.Write(res0, X, Y, S);\n";
+ }
+ c += "}\n";
+
+ return c;
+}
+} // namespace
+
+GPUOperation CreateDepthwiseConvolution2D(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const DepthwiseConvolution2DAttributes &attr)
+{
+ bool weights_are_buffer = device_info.IsMali();
+ GPUOperation op(definition);
+ op.args_.AddInt("kernel_size_x", attr.weights.shape.w);
+ op.args_.AddInt("stride_x", attr.strides.w);
+ op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+ op.args_.AddInt("dilation_x", attr.dilations.w);
+ op.args_.AddInt("kernel_size_y", attr.weights.shape.h);
+ op.args_.AddInt("stride_y", attr.strides.h);
+ op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+ op.args_.AddInt("dilation_y", attr.dilations.h);
+ if (!IsSpecializedCase(attr.weights.shape.o))
+ {
+ op.args_.AddInt("ch_multiplier", attr.weights.shape.o);
+ }
+ const bool stride_correction = definition.IsBatchSupported() && attr.strides.w != 1;
+ op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction, attr.weights.shape.o,
+ weights_are_buffer, false, &op);
+ UploadWeightsForDWConv2D(attr.weights, weights_are_buffer, definition.precision, &op);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+
+ TensorLinearDescriptor desc;
+ desc.storage_type =
+ weights_are_buffer ? LinearStorageType::BUFFER : LinearStorageType::TEXTURE_2D;
+ desc.element_type = definition.GetDataType();
+ desc.UploadLinearData(attr.bias);
+ op.args_.AddObject("biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+ return op;
+}
+
+GPUOperation
+CreateDepthwiseConvolution2DDynamicWeights(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const DepthwiseConvolution2DAttributes &attr)
+{
+ GPUOperation op(definition);
+ op.args_.AddInt("stride_x", attr.strides.w);
+ op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+ op.args_.AddInt("dilation_x", attr.dilations.w);
+ op.args_.AddInt("stride_y", attr.strides.h);
+ op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+ op.args_.AddInt("dilation_y", attr.dilations.h);
+ const bool stride_correction = definition.IsBatchSupported() && attr.strides.w != 1;
+ op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction, 1, false, true, &op);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+
+ TensorLinearDescriptor desc;
+ desc.storage_type =
+ device_info.IsMali() ? LinearStorageType::BUFFER : LinearStorageType::TEXTURE_2D;
+ desc.element_type = definition.GetDataType();
+ desc.UploadLinearData(attr.bias);
+ op.args_.AddObject("biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+ return op;
+}
+
+GPUOperation CreateDepthwiseConvolution3D(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const DepthwiseConvolution3DAttributes &attr)
+{
+ bool weights_are_buffer = device_info.IsMali();
+ GPUOperation op(definition);
+ op.args_.AddInt("kernel_size_x", attr.weights.shape.w);
+ op.args_.AddInt("stride_x", attr.strides.w);
+ op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+ op.args_.AddInt("dilation_x", attr.dilations.w);
+ op.args_.AddInt("kernel_size_y", attr.weights.shape.h);
+ op.args_.AddInt("stride_y", attr.strides.h);
+ op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+ op.args_.AddInt("dilation_y", attr.dilations.h);
+ op.args_.AddInt("kernel_size_z", attr.weights.shape.d);
+ op.args_.AddInt("stride_z", attr.strides.d);
+ op.args_.AddInt("padding_z", -attr.padding.prepended.d);
+ op.args_.AddInt("dilation_z", attr.dilations.d);
+ if (!IsSpecializedCase(attr.weights.shape.o))
+ {
+ op.args_.AddInt("ch_multiplier", attr.weights.shape.o);
+ }
+ const bool stride_correction = definition.IsBatchSupported() && attr.strides.w != 1;
+ op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction, attr.weights.shape.o,
+ weights_are_buffer, false, &op);
+ UploadWeightsForDWConv3D(attr.weights, weights_are_buffer, definition.precision, &op);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+
+ TensorLinearDescriptor desc;
+ desc.storage_type =
+ weights_are_buffer ? LinearStorageType::BUFFER : LinearStorageType::TEXTURE_2D;
+ desc.element_type = definition.GetDataType();
+ desc.UploadLinearData(attr.bias);
+ op.args_.AddObject("biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
+ return op;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_DEPTHWISE_CONV_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_DEPTHWISE_CONV_H__
+
+#include <vector>
+
+#include "open_cl/Buffer.h"
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/LinearStorage.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Texture2d.h"
+#include "open_cl/Util.h"
+#include "open_cl/DataType.h"
+#include "open_cl/Operations.h"
+#include "open_cl/Shape.h"
+#include "open_cl/Status.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Types.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+template <DataType S, typename T>
+void RearrangeWeightsForDWConv2D(const InternalTensor<OHWI, S> &weights, absl::Span<T> dst)
+{
+ const int dst_channels = weights.shape.i * weights.shape.o;
+ const int dst_depth = DivideRoundUp(dst_channels, 4);
+ const int kernel_x = weights.shape.w;
+ const int kernel_y = weights.shape.h;
+
+ int counter = 0;
+ for (int d = 0; d < dst_depth; ++d)
+ {
+ for (int y = 0; y < kernel_y; ++y)
+ {
+ for (int x = 0; x < kernel_x; ++x)
+ {
+ T filter_val;
+ for (int i = 0; i < 4; ++i)
+ {
+ const int d_ch = d * 4 + i;
+ if (d_ch < dst_channels)
+ {
+ const int f_index =
+ weights.shape.LinearIndex({d_ch % weights.shape.o, y, x, d_ch / weights.shape.o});
+ filter_val[i] = weights.data[f_index];
+ }
+ else
+ {
+ filter_val[i] = 0.0f;
+ }
+ }
+ dst[counter++] = filter_val;
+ }
+ }
+ }
+}
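+// Illustrative layout note (added comment): the weights are emitted as
+// dst_depth x kernel_y x kernel_x groups of four channels, zero-padded in the
+// last group. For example, weights.shape.i = 3 and weights.shape.o = 1 give a
+// single slice (dst_depth = 1) and each (y, x) tap is stored as
+// {w[ch0], w[ch1], w[ch2], 0.0f}.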
+
+template <DataType T>
+void UploadWeightsForDWConv2D(const InternalTensor<OHWI, T> &weights, bool weights_are_buffer,
+ CalculationsPrecision precision, GPUOperation *op)
+{
+ const int dst_channels = weights.shape.i * weights.shape.o;
+ const int dst_slices = DivideRoundUp(dst_channels, 4);
+ const int kernel_x = weights.shape.w;
+ const int kernel_y = weights.shape.h;
+
+ const int elements_count = kernel_x * kernel_y * dst_slices;
+
+ const bool fp32_weights = precision == CalculationsPrecision::F32;
+ const int float4_size = fp32_weights ? 16 : 8;
+
+ std::vector<uint8_t> data(float4_size * elements_count);
+
+ if (fp32_weights)
+ {
+ float4 *ptr = reinterpret_cast<float4 *>(data.data());
+ RearrangeWeightsForDWConv2D(weights, absl::MakeSpan(ptr, elements_count));
+ }
+ // TODO: F16 weights are not supported yet; only the F32 path above is
+ // implemented. The half-precision rearrangement below is kept for reference.
+ //
+ // else {
+ // half4* ptr = reinterpret_cast<half4*>(data.data());
+ // RearrangeWeightsForDWConv2D(weights, absl::MakeSpan(ptr, elements_count));
+ // }
+
+ if (weights_are_buffer)
+ {
+ BufferDescriptor desc;
+ desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.element_size = 4;
+ desc.size = float4_size * elements_count;
+ desc.data = std::move(data);
+ op->args_.AddObject("weights", absl::make_unique<BufferDescriptor>(desc));
+ }
+ else
+ {
+ Texture2DDescriptor desc;
+ desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.size = int2(kernel_x * kernel_y, dst_slices);
+ desc.data = std::move(data);
+ op->args_.AddObject("weights", absl::make_unique<Texture2DDescriptor>(desc));
+ }
+}
+
+template <DataType S, typename T>
+void RearrangeWeightsForDWConv3D(const InternalTensor<OHWDI, S> &weights, absl::Span<T> dst)
+{
+ const int dst_channels = weights.shape.i * weights.shape.o;
+ const int dst_slices = DivideRoundUp(dst_channels, 4);
+ const int kernel_x = weights.shape.w;
+ const int kernel_y = weights.shape.h;
+ const int kernel_z = weights.shape.d;
+
+ int counter = 0;
+ for (int d = 0; d < dst_slices; ++d)
+ {
+ for (int z = 0; z < kernel_z; ++z)
+ {
+ for (int y = 0; y < kernel_y; ++y)
+ {
+ for (int x = 0; x < kernel_x; ++x)
+ {
+ T filter_val;
+ for (int i = 0; i < 4; ++i)
+ {
+ const int d_ch = d * 4 + i;
+ if (d_ch < dst_channels)
+ {
+ const int f_index = weights.shape.LinearIndex(
+ {d_ch % weights.shape.o, y, x, z, d_ch / weights.shape.o});
+ filter_val[i] = weights.data[f_index];
+ }
+ else
+ {
+ filter_val[i] = 0.0f;
+ }
+ }
+ dst[counter++] = filter_val;
+ }
+ }
+ }
+ }
+}
+
+template <DataType T>
+void UploadWeightsForDWConv3D(const InternalTensor<OHWDI, T> &weights, bool weights_are_buffer,
+ CalculationsPrecision precision, GPUOperation *op)
+{
+ const int dst_channels = weights.shape.i * weights.shape.o;
+ const int dst_slices = DivideRoundUp(dst_channels, 4);
+ const int kernel_x = weights.shape.w;
+ const int kernel_y = weights.shape.h;
+ const int kernel_z = weights.shape.d;
+
+ const int elements_count = kernel_x * kernel_y * kernel_z * dst_slices;
+
+ const bool fp32_weights = precision == CalculationsPrecision::F32;
+ const int float4_size = fp32_weights ? 16 : 8;
+
+ std::vector<uint8_t> data(float4_size * elements_count);
+
+ if (fp32_weights)
+ {
+ float4 *ptr = reinterpret_cast<float4 *>(data.data());
+ RearrangeWeightsForDWConv3D(weights, absl::MakeSpan(ptr, elements_count));
+ }
+ // TODO: F16 weights are not supported yet; only the F32 path above is
+ // implemented. The half-precision rearrangement below is kept for reference.
+ //
+ // else {
+ // half4* ptr = reinterpret_cast<half4*>(data.data());
+ // RearrangeWeightsForDWConv3D(weights, absl::MakeSpan(ptr, elements_count));
+ // }
+
+ if (weights_are_buffer)
+ {
+ BufferDescriptor desc;
+ desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.element_size = 4;
+ desc.size = float4_size * elements_count;
+ desc.data = std::move(data);
+ op->args_.AddObject("weights", absl::make_unique<BufferDescriptor>(std::move(desc)));
+ }
+ else
+ {
+ Texture2DDescriptor desc;
+ desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.size = int2(kernel_x * kernel_y * kernel_z, dst_slices);
+ desc.data = std::move(data);
+ op->args_.AddObject("weights", absl::make_unique<Texture2DDescriptor>(std::move(desc)));
+ }
+}
+
+GPUOperation CreateDepthwiseConvolution2D(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const DepthwiseConvolution2DAttributes &attr);
+
+GPUOperation
+CreateDepthwiseConvolution2DDynamicWeights(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const DepthwiseConvolution2DAttributes &attr);
+
+GPUOperation CreateDepthwiseConvolution3D(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const DepthwiseConvolution3DAttributes &attr);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_DEPTHWISE_CONV_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DepthwiseConv3x3.h"
+
+#include <string>
+#include <utility>
+
+#include "open_cl/kernels/Util.h"
+#include "open_cl/kernels/WorkGroupPicking.h"
+#include "open_cl/Precision.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+DepthwiseConv3x3::DepthwiseConv3x3(const OperationDef &definition, bool weights_are_buffer,
+ bool local_mem_uploads, const DeviceInfo &device_info)
+ : GPUOperation(definition), local_mem_uploads_(local_mem_uploads)
+{
+ work_group_size_ = int3(8, 4, 1);
+ code_ = GenerateDepthwiseConvCode(definition_, weights_are_buffer, local_mem_uploads_);
+
+ if (definition_.precision == CalculationsPrecision::F16 && device_info.IsPowerVR())
+ {
+ compiler_options_.push_back(CompilerOptions::POWERVR_FP16);
+ }
+}
+
+DepthwiseConv3x3::DepthwiseConv3x3(DepthwiseConv3x3 &&operation)
+ : GPUOperation(std::move(operation)), local_mem_uploads_(operation.local_mem_uploads_)
+{
+}
+
+DepthwiseConv3x3 &DepthwiseConv3x3::operator=(DepthwiseConv3x3 &&operation)
+{
+ if (this != &operation)
+ {
+ std::swap(local_mem_uploads_, operation.local_mem_uploads_);
+ GPUOperation::operator=(std::move(operation));
+ }
+ return *this;
+}
+
+std::string DepthwiseConv3x3::GenerateDepthwiseConvCode(const OperationDef &op_def,
+ bool weights_are_buffer,
+ bool local_mem_uploads)
+{
+ auto src_desc = op_def.src_tensors[0];
+ src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
+ AddSrcTensor("src_tensor", src_desc);
+ AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
+
+ const auto src_tensor_type = op_def.src_tensors[0].storage_type;
+
+ const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
+ src_tensor_type == TensorStorageType::IMAGE_BUFFER;
+
+ std::string c = GetCommonDefines(op_def.precision);
+ if (local_mem_uploads)
+ {
+ c += "__attribute__((reqd_work_group_size(8, 4, 1)))\n";
+ }
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::BATCH))
+ {
+ c += " int linear_id = get_global_id(0);\n";
+ c += " int X = (linear_id / args.dst_tensor.Batch()) * 2;\n";
+ c += " int B = linear_id % args.dst_tensor.Batch();\n";
+ c += " args.dst_tensor.SetBatchRef(B);\n";
+ c += " args.src_tensor.SetBatchRef(B);\n";
+ }
+ else
+ {
+ c += " int X = get_global_id(0) * 2;\n";
+ }
+ c += " int Y = get_global_id(1) * 2;\n";
+ c += " int S = get_global_id(2);\n";
+ c += " ACCUM_FLT4 r0 = (ACCUM_FLT4)(0.0f);\n";
+ c += " ACCUM_FLT4 r1 = (ACCUM_FLT4)(0.0f);\n";
+ c += " ACCUM_FLT4 r2 = (ACCUM_FLT4)(0.0f);\n";
+ c += " ACCUM_FLT4 r3 = (ACCUM_FLT4)(0.0f);\n";
+ if (!local_mem_uploads)
+ {
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
+ "|| S >= args.dst_tensor.Slices()) { \n";
+ c += " return; \n";
+ c += " } \n";
+ }
+ if (local_mem_uploads)
+ {
+ c += " __local FLT4 f[10];\n";
+ c += " event_t e = async_work_group_copy(f, args.weights.GetPtr() + S * "
+ "10, 10, 0);\n";
+ c += " wait_group_events(1, &e);\n";
+ }
+ else if (weights_are_buffer)
+ {
+ c += " __global FLT4* f = args.weights.GetPtr() + S * 10;\n";
+ }
+ c += " FLT4 s0;\n";
+ c += " FLT4 s1;\n";
+ c += " FLT4 s2;\n";
+ c += " FLT4 s3;\n";
+ std::string W[9] = {"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8"};
+ std::string bias = "bias";
+ std::string xc[4] = {"X - 1", "X", "X + 1", "X + 2"};
+ std::string yc[4] = {"Y - 1", "Y", "Y + 1", "Y + 2"};
+ if (!weights_are_buffer)
+ {
+ c += " FLT4 f0 = args.weights.Read(0, S);\n";
+ c += " FLT4 f1 = args.weights.Read(1, S);\n";
+ c += " FLT4 f2 = args.weights.Read(2, S);\n";
+ c += " FLT4 f3 = args.weights.Read(3, S);\n";
+ c += " FLT4 f4 = args.weights.Read(4, S);\n";
+ c += " FLT4 f5 = args.weights.Read(5, S);\n";
+ c += " FLT4 f6 = args.weights.Read(6, S);\n";
+ c += " FLT4 f7 = args.weights.Read(7, S);\n";
+ c += " FLT4 f8 = args.weights.Read(8, S);\n";
+ }
+ if (manual_clamp)
+ {
+ c += " int x0 = X - 1;\n";
+ c += " int x1 = X;\n";
+ c += " int x2 = X + 1;\n";
+ c += " int x3 = X + 2;\n";
+ c += " int y0 = Y - 1;\n";
+ c += " int y1 = Y;\n";
+ c += " int y2 = Y + 1;\n";
+ c += " int y3 = Y + 2;\n";
+ c += " bool x0_in = x0 >= 0 && x0 < args.dst_tensor.Width();\n";
+ c += " bool x1_in = x1 >= 0 && x1 < args.dst_tensor.Width();\n";
+ c += " bool x2_in = x2 >= 0 && x2 < args.dst_tensor.Width();\n";
+ c += " bool x3_in = x3 >= 0 && x3 < args.dst_tensor.Width();\n";
+ c += " bool y0_in = y0 >= 0 && y0 < args.dst_tensor.Height();\n";
+ c += " bool y1_in = y1 >= 0 && y1 < args.dst_tensor.Height();\n";
+ c += " bool y2_in = y2 >= 0 && y2 < args.dst_tensor.Height();\n";
+ c += " bool y3_in = y3 >= 0 && y3 < args.dst_tensor.Height();\n";
+ c += " x0 = clamp(x0, 0, args.dst_tensor.Width() - 1);\n";
+ c += " x1 = clamp(x1, 0, args.dst_tensor.Width() - 1);\n";
+ c += " x2 = clamp(x2, 0, args.dst_tensor.Width() - 1);\n";
+ c += " x3 = clamp(x3, 0, args.dst_tensor.Width() - 1);\n";
+ c += " y0 = clamp(y0, 0, args.dst_tensor.Height() - 1);\n";
+ c += " y1 = clamp(y1, 0, args.dst_tensor.Height() - 1);\n";
+ c += " y2 = clamp(y2, 0, args.dst_tensor.Height() - 1);\n";
+ c += " y3 = clamp(y3, 0, args.dst_tensor.Height() - 1);\n";
+ if (src_tensor_type == TensorStorageType::BUFFER)
+ {
+ c += " __global FLT4* src_loc = "
+ "args.src_tensor.GetPtrWithSliceOffset(S);\n";
+ }
+ xc[0] = "x0";
+ xc[1] = "x1";
+ xc[2] = "x2";
+ xc[3] = "x3";
+ yc[0] = "y0";
+ yc[1] = "y1";
+ yc[2] = "y2";
+ yc[3] = "y3";
+ }
+ if (local_mem_uploads || weights_are_buffer)
+ {
+ W[0] = "f[0]";
+ W[1] = "f[1]";
+ W[2] = "f[2]";
+ W[3] = "f[3]";
+ W[4] = "f[4]";
+ W[5] = "f[5]";
+ W[6] = "f[6]";
+ W[7] = "f[7]";
+ W[8] = "f[8]";
+ bias = "f[9]";
+ }
+ auto read_4x_line = [&](int y) {
+ if (src_tensor_type == TensorStorageType::BUFFER)
+ {
+ const std::string y_in = "y" + std::to_string(y) + "_in";
+ c += " s0 = src_loc[args.src_tensor.GetWHOffset(" + xc[0] + ", " + yc[y] +
+ ")] * (FLT)(x0_in && " + y_in + ");\n";
+ c += " s1 = src_loc[args.src_tensor.GetWHOffset(" + xc[1] + ", " + yc[y] +
+ ")] * (FLT)(x1_in && " + y_in + ");\n";
+ c += " s2 = src_loc[args.src_tensor.GetWHOffset(" + xc[2] + ", " + yc[y] +
+ ")] * (FLT)(x2_in && " + y_in + ");\n";
+ c += " s3 = src_loc[args.src_tensor.GetWHOffset(" + xc[3] + ", " + yc[y] +
+ ")] * (FLT)(x3_in && " + y_in + ");\n";
+ }
+ else if (src_tensor_type == TensorStorageType::IMAGE_BUFFER)
+ {
+ const std::string y_in = "y" + std::to_string(y) + "_in";
+ c += " s0 = args.src_tensor.Read(" + xc[0] + ", " + yc[y] + ", S) * (FLT)(x0_in && " +
+ y_in + ");\n";
+ c += " s1 = args.src_tensor.Read(" + xc[1] + ", " + yc[y] + ", S) * (FLT)(x1_in && " +
+ y_in + ");\n";
+ c += " s2 = args.src_tensor.Read(" + xc[2] + ", " + yc[y] + ", S) * (FLT)(x2_in && " +
+ y_in + ");\n";
+ c += " s3 = args.src_tensor.Read(" + xc[3] + ", " + yc[y] + ", S) * (FLT)(x3_in && " +
+ y_in + ");\n";
+ }
+ else
+ {
+ c += " s0 = args.src_tensor.Read(" + xc[0] + ", " + yc[y] + ", S);\n";
+ c += " s1 = args.src_tensor.Read(" + xc[1] + ", " + yc[y] + ", S);\n";
+ c += " s2 = args.src_tensor.Read(" + xc[2] + ", " + yc[y] + ", S);\n";
+ c += " s3 = args.src_tensor.Read(" + xc[3] + ", " + yc[y] + ", S);\n";
+ }
+ };
+ c += " {\n";
+ read_4x_line(0);
+ c += " r0 += TO_ACCUM_TYPE(" + W[0] + " * s0);\n";
+ c += " r0 += TO_ACCUM_TYPE(" + W[1] + " * s1);\n";
+ c += " r1 += TO_ACCUM_TYPE(" + W[0] + " * s1);\n";
+ c += " r0 += TO_ACCUM_TYPE(" + W[2] + " * s2);\n";
+ c += " r1 += TO_ACCUM_TYPE(" + W[1] + " * s2);\n";
+ c += " r1 += TO_ACCUM_TYPE(" + W[2] + " * s3);\n";
+ c += " }\n";
+ c += " {\n";
+ read_4x_line(1);
+ c += " r0 += TO_ACCUM_TYPE(" + W[3] + " * s0);\n";
+ c += " r2 += TO_ACCUM_TYPE(" + W[0] + " * s0);\n";
+ c += " r0 += TO_ACCUM_TYPE(" + W[4] + " * s1);\n";
+ c += " r1 += TO_ACCUM_TYPE(" + W[3] + " * s1);\n";
+ c += " r2 += TO_ACCUM_TYPE(" + W[1] + " * s1);\n";
+ c += " r3 += TO_ACCUM_TYPE(" + W[0] + " * s1);\n";
+ c += " r0 += TO_ACCUM_TYPE(" + W[5] + " * s2);\n";
+ c += " r1 += TO_ACCUM_TYPE(" + W[4] + " * s2);\n";
+ c += " r2 += TO_ACCUM_TYPE(" + W[2] + " * s2);\n";
+ c += " r3 += TO_ACCUM_TYPE(" + W[1] + " * s2);\n";
+ c += " r1 += TO_ACCUM_TYPE(" + W[5] + " * s3);\n";
+ c += " r3 += TO_ACCUM_TYPE(" + W[2] + " * s3);\n";
+ c += " }\n";
+ c += " {\n";
+ read_4x_line(2);
+ c += " r0 += TO_ACCUM_TYPE(" + W[6] + " * s0);\n";
+ c += " r2 += TO_ACCUM_TYPE(" + W[3] + " * s0);\n";
+ c += " r0 += TO_ACCUM_TYPE(" + W[7] + " * s1);\n";
+ c += " r1 += TO_ACCUM_TYPE(" + W[6] + " * s1);\n";
+ c += " r2 += TO_ACCUM_TYPE(" + W[4] + " * s1);\n";
+ c += " r3 += TO_ACCUM_TYPE(" + W[3] + " * s1);\n";
+ c += " r0 += TO_ACCUM_TYPE(" + W[8] + " * s2);\n";
+ c += " r1 += TO_ACCUM_TYPE(" + W[7] + " * s2);\n";
+ c += " r2 += TO_ACCUM_TYPE(" + W[5] + " * s2);\n";
+ c += " r3 += TO_ACCUM_TYPE(" + W[4] + " * s2);\n";
+ c += " r1 += TO_ACCUM_TYPE(" + W[8] + " * s3);\n";
+ c += " r3 += TO_ACCUM_TYPE(" + W[5] + " * s3);\n";
+ c += " }\n";
+ c += " {\n";
+ read_4x_line(3);
+ c += " r2 += TO_ACCUM_TYPE(" + W[6] + " * s0);\n";
+ c += " r2 += TO_ACCUM_TYPE(" + W[7] + " * s1);\n";
+ c += " r3 += TO_ACCUM_TYPE(" + W[6] + " * s1);\n";
+ c += " r2 += TO_ACCUM_TYPE(" + W[8] + " * s2);\n";
+ c += " r3 += TO_ACCUM_TYPE(" + W[7] + " * s2);\n";
+ c += " r3 += TO_ACCUM_TYPE(" + W[8] + " * s3);\n";
+ c += " }\n";
+ if (!weights_are_buffer)
+ {
+ c += " FLT4 bias = args.weights.Read(9, S);\n";
+ }
+ c += " r0 += TO_ACCUM_TYPE(" + bias + ");\n";
+ c += " r1 += TO_ACCUM_TYPE(" + bias + ");\n";
+ c += " r2 += TO_ACCUM_TYPE(" + bias + ");\n";
+ c += " r3 += TO_ACCUM_TYPE(" + bias + ");\n";
+ if (local_mem_uploads)
+ {
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
+ "|| S >= args.dst_tensor.Slices()) { \n";
+ c += " return; \n";
+ c += " } \n";
+ }
+ c += " if(X + 0 < args.dst_tensor.Width() && Y + 0 < "
+ "args.dst_tensor.Height()) {\n";
+ c += " FLT4 result = TO_FLT4(r0);\n";
+ c += " args.dst_tensor.Write(result, X + 0, Y + 0, S)\n";
+ c += " }\n";
+ c += " if(X + 1 < args.dst_tensor.Width() && Y + 0 < "
+ "args.dst_tensor.Height()) {\n";
+ c += " FLT4 result = TO_FLT4(r1);\n";
+ c += " args.dst_tensor.Write(result, X + 1, Y + 0, S)\n";
+ c += " }\n";
+ c += " if(X + 0 < args.dst_tensor.Width() && Y + 1 < "
+ "args.dst_tensor.Height()) {\n";
+ c += " FLT4 result = TO_FLT4(r2);\n";
+ c += " args.dst_tensor.Write(result, X + 0, Y + 1, S)\n";
+ c += " }\n";
+ c += " if(X + 1 < args.dst_tensor.Width() && Y + 1 < "
+ "args.dst_tensor.Height()) {\n";
+ c += " FLT4 result = TO_FLT4(r3);\n";
+ c += " args.dst_tensor.Write(result, X + 1, Y + 1, S)\n";
+ c += " }\n";
+ c += "}\n";
+
+ return c;
+}
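+// Note (added comment): the generated kernel computes a 2x2 output tile per
+// work item (X and Y are scaled by 2 and r0..r3 accumulate the four outputs),
+// which is why GetGridSize() below divides width and height by 2.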
+
+int3 DepthwiseConv3x3::GetGridSize() const
+{
+ const int grid_x = DivideRoundUp(dst_[0]->Width(), 2) * dst_[0]->Batch();
+ const int grid_y = DivideRoundUp(dst_[0]->Height(), 2);
+ const int grid_z = dst_[0]->Slices();
+ return int3(grid_x, grid_y, grid_z);
+}
+
+void DepthwiseConv3x3::GetPossibleKernelWorkGroups(TuningType tuning_type,
+ const DeviceInfo &device_info,
+ const KernelInfo &kernel_info,
+ std::vector<int3> *work_groups) const
+{
+ if (local_mem_uploads_)
+ {
+ work_groups->push_back(work_group_size_);
+ }
+ else
+ {
+ GetPossibleWorkGroups(tuning_type, device_info, kernel_info, grid_size_, work_groups);
+ }
+}
+
+bool IsDepthwiseConv3x3Supported(const DepthwiseConvolution2DAttributes &attr)
+{
+ return attr.weights.shape.o == 1 && attr.dilations.w == 1 && attr.dilations.h == 1 &&
+ attr.weights.shape.w == 3 && attr.weights.shape.h == 3 && attr.strides.w == 1 &&
+ attr.strides.h == 1 && attr.padding.prepended.w == 1 && attr.padding.prepended.h == 1 &&
+ attr.padding.appended.w == 1 && attr.padding.appended.h == 1;
+}
+
+DepthwiseConv3x3 CreateDepthwiseConv3x3(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const DepthwiseConvolution2DAttributes &attr)
+{
+ bool weights_are_buffer = device_info.IsPowerVR() || device_info.IsMali();
+ bool local_mem_uploads = weights_are_buffer && device_info.IsPowerVR();
+ DepthwiseConv3x3 result(definition, weights_are_buffer, local_mem_uploads, device_info);
+ result.UploadWeightsAndBiases(attr.weights, attr.bias, weights_are_buffer);
+ return result;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_DEPTHWISE_CONV_3X3_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_DEPTHWISE_CONV_3X3_H__
+
+#include <memory>
+#include <vector>
+
+#include "open_cl/Buffer.h"
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Texture2d.h"
+#include "open_cl/Util.h"
+#include "open_cl/DataType.h"
+#include "open_cl/Operations.h"
+#include "open_cl/Shape.h"
+#include "open_cl/Status.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Types.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class DepthwiseConv3x3 : public GPUOperation
+{
+public:
+ DepthwiseConv3x3() = default;
+ void GetPossibleKernelWorkGroups(TuningType tuning_type, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info,
+ std::vector<int3> *work_groups) const override;
+ int3 GetGridSize() const override;
+
+ // Move only
+ DepthwiseConv3x3(DepthwiseConv3x3 &&operation);
+ DepthwiseConv3x3 &operator=(DepthwiseConv3x3 &&operation);
+ DepthwiseConv3x3(const DepthwiseConv3x3 &) = delete;
+ DepthwiseConv3x3 &operator=(const DepthwiseConv3x3 &) = delete;
+
+private:
+ explicit DepthwiseConv3x3(const OperationDef &definition, bool weights_are_buffer,
+ bool local_mem_uploads, const DeviceInfo &device_info);
+ template <DataType T>
+ void UploadWeightsAndBiases(const InternalTensor<OHWI, T> &weights,
+ const InternalTensor<Linear, T> &biases, bool weights_are_buffer);
+
+ friend DepthwiseConv3x3 CreateDepthwiseConv3x3(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const DepthwiseConvolution2DAttributes &attr);
+
+ template <DataType S, typename T>
+ void RearrangeWeightsAndBiasesData(const InternalTensor<OHWI, S> &weights,
+ const InternalTensor<Linear, S> &biases, absl::Span<T> dst);
+
+ std::string GenerateDepthwiseConvCode(const OperationDef &op_def, bool weights_are_buffer,
+ bool local_mem_uploads);
+
+ bool local_mem_uploads_;
+};
+
+template <DataType T>
+void DepthwiseConv3x3::UploadWeightsAndBiases(const InternalTensor<OHWI, T> &weights,
+ const InternalTensor<Linear, T> &biases,
+ bool weights_are_buffer)
+{
+ const int src_depth = DivideRoundUp(weights.shape.i, 4);
+ int texture_width = 10; // 3x3 kernel + 1 bias
+ int texture_height = src_depth;
+ const int elements_count = texture_width * texture_height;
+ const bool fp32_weights = definition_.precision == CalculationsPrecision::F32;
+ const int float4_size = fp32_weights ? 16 : 8;
+
+ std::vector<uint8_t> data(float4_size * elements_count);
+ if (fp32_weights)
+ {
+ float4 *ptr = reinterpret_cast<float4 *>(data.data());
+ RearrangeWeightsAndBiasesData(weights, biases, absl::MakeSpan(ptr, elements_count));
+ }
+ // TODO: F16 weights are not supported yet; only the F32 path above is
+ // implemented. The half-precision rearrangement below is kept for reference.
+ //
+ // else {
+ // half4* ptr = reinterpret_cast<half4*>(data.data());
+ // RearrangeWeightsAndBiasesData(weights, biases,
+ // absl::MakeSpan(ptr, elements_count));
+ // }
+
+ if (weights_are_buffer)
+ {
+ BufferDescriptor desc;
+ desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.element_size = 4;
+ desc.size = float4_size * elements_count;
+ desc.data = std::move(data);
+ args_.AddObject("weights", absl::make_unique<BufferDescriptor>(std::move(desc)));
+ }
+ else
+ {
+ Texture2DDescriptor desc;
+ desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
+ desc.size = int2(texture_width, texture_height);
+ desc.data = std::move(data);
+ args_.AddObject("weights", absl::make_unique<Texture2DDescriptor>(std::move(desc)));
+ }
+}
+
+template <DataType S, typename T>
+void DepthwiseConv3x3::RearrangeWeightsAndBiasesData(const InternalTensor<OHWI, S> &weights,
+ const InternalTensor<Linear, S> &biases,
+ absl::Span<T> dst)
+{
+ const int src_depth = DivideRoundUp(weights.shape.i, 4);
+
+ int counter = 0;
+ for (int s = 0; s < src_depth; ++s)
+ {
+ for (int y = 0; y < 3; ++y)
+ {
+ for (int x = 0; x < 3; ++x)
+ {
+ T filter_val;
+ for (int i = 0; i < 4; ++i)
+ {
+ const int s_ch = s * 4 + i;
+ if (s_ch < weights.shape.i)
+ {
+ const int f_index = weights.shape.LinearIndex({0, y, x, s_ch});
+ filter_val[i] = weights.data[f_index];
+ }
+ else
+ {
+ filter_val[i] = 0.0f;
+ }
+ }
+ dst[counter++] = filter_val;
+ }
+ }
+
+ T bias_val;
+ for (int i = 0; i < 4; ++i)
+ {
+ const int dst_ch = s * 4 + i;
+ bias_val[i] = dst_ch >= biases.shape.v ? 0.0f : biases.data[dst_ch];
+ }
+ dst[counter++] = bias_val;
+ }
+}
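+// Illustrative layout note (added comment): each source slice contributes ten
+// float4 entries: the nine 3x3 filter taps in (y, x) order followed by one bias
+// value. This matches texture_width = 10 in UploadWeightsAndBiases and the
+// f[0..8] / f[9] indexing in the generated kernel.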
+
+bool IsDepthwiseConv3x3Supported(const DepthwiseConvolution2DAttributes &attr);
+
+DepthwiseConv3x3 CreateDepthwiseConv3x3(const DeviceInfo &device_info,
+ const OperationDef &definition,
+ const DepthwiseConvolution2DAttributes &attr);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_DEPTHWISE_CONV_3X3_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GpuOperation.h"
+
+#include "Util.h"
+#include "WorkGroupPicking.h"
+#include "open_cl/AccessType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+std::string GetElementWiseCode(const OperationDef &op_def, bool check_src_slices)
+{
+ std::string c = GetCommonDefines(op_def.precision);
+
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ c += " int X = get_global_id(0);\n";
+ c += " int Y = get_global_id(1);\n";
+ c += " int Z = get_global_id(2);\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+ "Z >= args.dst_tensor.Slices()) return; \n";
+ if (check_src_slices)
+ {
+ c += " FLT4 src = (FLT4)(0.0f);\n";
+ c += " if (Z < args.src_tensor.Slices()) {\n";
+ c += " src = args.src_tensor.Read(X, Y, Z);\n";
+ c += " }\n";
+ }
+ else
+ {
+ c += " FLT4 src = args.src_tensor.Read(X, Y, Z);\n";
+ }
+ c += " args.dst_tensor.Write(src, X, Y, Z);\n";
+ c += "} \n";
+ return c;
+}
+
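+// Computes how many work groups to launch per grid dimension. Illustrative
+// numbers (chosen here, not from the sources): for grid_size = (100, 50, 8),
+// work_group_size = (8, 4, 1) and work_group_launch_order = (0, 1, 2), the
+// result is (ceil(100/8), ceil(50/4), ceil(8/1)) = (13, 13, 8); a non-identity
+// launch order only permutes which quotient lands in which component.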
+int3 GetWorkGroupsCount(int grid_dimension, const int3 &grid_size, const int3 &work_group_size,
+ const int3 &work_group_launch_order)
+{
+ int3 work_groups_count;
+ if (grid_dimension == 1)
+ {
+ work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
+ work_groups_count.y = 1;
+ work_groups_count.z = 1;
+ }
+ else if (grid_dimension == 2)
+ {
+ int3 wgs;
+ wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
+ wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
+ work_groups_count.x = wgs[work_group_launch_order[0]];
+ work_groups_count.y = wgs[work_group_launch_order[1]];
+ work_groups_count.z = 1;
+ }
+ else
+ { // grid_dimension == 3
+ int3 wgs;
+ wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
+ wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
+ wgs.z = DivideRoundUp(grid_size.z, work_group_size.z);
+ work_groups_count.x = wgs[work_group_launch_order[0]];
+ work_groups_count.y = wgs[work_group_launch_order[1]];
+ work_groups_count.z = wgs[work_group_launch_order[2]];
+ }
+ return work_groups_count;
+}
+
+} // namespace
+
+DataType OperationDef::GetDataType() const { return DeduceDataTypeFromPrecision(precision); }
+
+DataType OperationDef::GetPrimaryDataType() const { return src_tensors[0].data_type; }
+TensorStorageType OperationDef::GetPrimaryStorageType() const
+{
+ return src_tensors[0].storage_type;
+}
+
+bool OperationDef::IsBatchSupported() const
+{
+ for (const auto &src : src_tensors)
+ {
+ if (HasAxis(src.layout, Axis::BATCH))
+ {
+ return true;
+ }
+ }
+ for (const auto &dst : dst_tensors)
+ {
+ if (HasAxis(dst.layout, Axis::BATCH))
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
+GPUOperation::GPUOperation(const OperationDef &definition) : definition_(definition) {}
+
+void GPUOperation::SetSrc(Tensor *ptr, int index)
+{
+ if (index >= (int)src_.size())
+ {
+ src_.resize(index + 1, nullptr);
+ }
+ src_[index] = ptr;
+}
+
+void GPUOperation::SetDst(Tensor *ptr, int index)
+{
+ if (index >= (int)dst_.size())
+ {
+ dst_.resize(index + 1, nullptr);
+ }
+ dst_[index] = ptr;
+}
+
+GPUOperation::GPUOperation(GPUOperation &&operation)
+ : args_(std::move(operation.args_)), code_(std::move(operation.code_)),
+ work_group_size_(operation.work_group_size_),
+ compiler_options_(std::move(operation.compiler_options_)),
+ tensor_to_grid_(operation.tensor_to_grid_), elementwise_(operation.elementwise_),
+ linkable_(operation.linkable_), check_src_channels_size_(operation.check_src_channels_size_),
+ definition_(std::move(operation.definition_)), src_(std::move(operation.src_)),
+ dst_(std::move(operation.dst_)), kernel_(std::move(operation.kernel_)),
+ grid_dimension_(operation.grid_dimension_),
+ work_group_launch_order_(operation.work_group_launch_order_), grid_size_(operation.grid_size_),
+ src_tensors_names_(std::move(operation.src_tensors_names_)),
+ dst_tensors_names_(std::move(operation.dst_tensors_names_)),
+ work_groups_count_(operation.work_groups_count_), linkable_count_(operation.linkable_count_),
+ elementwise_code_(std::move(operation.elementwise_code_))
+{
+}
+
+GPUOperation &GPUOperation::operator=(GPUOperation &&operation)
+{
+ if (this != &operation)
+ {
+ args_ = std::move(operation.args_);
+ code_ = std::move(operation.code_);
+ std::swap(work_group_size_, operation.work_group_size_);
+ compiler_options_ = std::move(operation.compiler_options_);
+ tensor_to_grid_ = operation.tensor_to_grid_;
+ elementwise_ = operation.elementwise_;
+ linkable_ = operation.linkable_;
+ check_src_channels_size_ = operation.check_src_channels_size_;
+ definition_ = std::move(operation.definition_);
+ src_ = std::move(operation.src_);
+ dst_ = std::move(operation.dst_);
+ kernel_ = std::move(operation.kernel_);
+ std::swap(grid_dimension_, operation.grid_dimension_);
+ std::swap(work_group_launch_order_, operation.work_group_launch_order_);
+ std::swap(grid_size_, operation.grid_size_);
+ src_tensors_names_ = std::move(operation.src_tensors_names_);
+ dst_tensors_names_ = std::move(operation.dst_tensors_names_);
+ std::swap(work_groups_count_, operation.work_groups_count_);
+ std::swap(linkable_count_, operation.linkable_count_);
+ elementwise_code_ = std::move(operation.elementwise_code_);
+ }
+ return *this;
+}
+
+absl::Status GPUOperation::AddOperation(GPUOperation *operation)
+{
+ linkable_count_ += 1;
+ std::string code = operation->code_;
+ std::string unique_postfix = absl::StrCat("_link", linkable_count_);
+ operation->args_.RenameArgs(unique_postfix, &code);
+ elementwise_code_ += "{\n" + code + "\n}\n";
+ RETURN_IF_ERROR(args_.Merge(std::move(operation->args_), unique_postfix));
+ for (size_t i = 0; i < operation->src_tensors_names_.size(); ++i)
+ {
+ definition_.src_tensors.push_back(operation->definition_.src_tensors[i + 1]);
+ src_tensors_names_.push_back(operation->src_tensors_names_[i] + unique_postfix);
+ }
+ for (size_t i = 0; i < operation->dst_tensors_names_.size(); ++i)
+ {
+ dst_tensors_names_.push_back(operation->dst_tensors_names_[i] + unique_postfix);
+ }
+ return absl::OkStatus();
+}
+
+void GPUOperation::AddSrcTensor(const std::string &tensor_name, const TensorDescriptor &desc)
+{
+ src_tensors_names_.push_back(tensor_name);
+ auto desc_new = std::make_unique<TensorDescriptor>(desc);
+ args_.AddObjectRef(tensor_name, AccessType::READ, std::move(desc_new));
+}
+
+void GPUOperation::AddSrcBuffer(const std::string &buffer_name, const BufferDescriptor &desc)
+{
+ src_tensors_names_.push_back(buffer_name);
+ auto desc_new = std::make_unique<BufferDescriptor>(desc);
+ args_.AddObjectRef(buffer_name, AccessType::READ, std::move(desc_new));
+}
+
+void GPUOperation::AddDstTensor(const std::string &tensor_name, const TensorDescriptor &desc)
+{
+ dst_tensors_names_.push_back(tensor_name);
+ auto desc_new = std::make_unique<TensorDescriptor>(desc);
+ args_.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new));
+}
+
+absl::Status GPUOperation::UpdateParams()
+{
+ for (size_t i = 0; i < src_tensors_names_.size(); ++i)
+ {
+ RETURN_IF_ERROR(args_.SetObjectRef(src_tensors_names_[i], src_[i]));
+ }
+ for (size_t i = 0; i < dst_tensors_names_.size(); ++i)
+ {
+ RETURN_IF_ERROR(args_.SetObjectRef(dst_tensors_names_[i], dst_[i]));
+ }
+ RETURN_IF_ERROR(BindArguments(&args_));
+ grid_size_ = GetGridSize();
+ work_groups_count_ =
+ GetWorkGroupsCount(grid_dimension_, grid_size_, work_group_size_, work_group_launch_order_);
+ return absl::OkStatus();
+}
+
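+// For elementwise operations, code_ holds only the snippet that transforms
+// in_out_value. AssembleCode wraps that snippet (plus the code of any linked
+// operations accumulated in elementwise_code_) into a generic element-wise
+// kernel that reads src_tensor, applies the transformation and writes
+// dst_tensor.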
+absl::Status GPUOperation::AssembleCode(const DeviceInfo &device_info, CLContext *context)
+{
+ if (elementwise_)
+ {
+ auto src_desc = absl::make_unique<TensorDescriptor>(definition_.src_tensors[0]);
+ if (definition_.IsBatchSupported())
+ {
+ src_desc->SetStateVar("BatchedWidth", "true");
+ }
+ src_tensors_names_.insert(src_tensors_names_.begin(), "src_tensor");
+ args_.AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
+
+ auto dst_desc = absl::make_unique<TensorDescriptor>(definition_.dst_tensors[0]);
+ if (definition_.IsBatchSupported())
+ {
+ dst_desc->SetStateVar("BatchedWidth", "true");
+ }
+ dst_tensors_names_.insert(dst_tensors_names_.begin(), "dst_tensor");
+ args_.AddObjectRef("dst_tensor", AccessType::WRITE, std::move(dst_desc));
+
+ elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_;
+ code_ = GetElementWiseCode(definition_, check_src_channels_size_);
+ RETURN_IF_ERROR(args_.AllocateObjects(context));
+ RETURN_IF_ERROR(
+ args_.TransformToCLCode(device_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_));
+ }
+ else
+ {
+ RETURN_IF_ERROR(args_.AllocateObjects(context));
+ RETURN_IF_ERROR(
+ args_.TransformToCLCode(device_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status GPUOperation::Compile(const CreationContext &creation_context)
+{
+ RETURN_IF_ERROR(AssembleCode(creation_context.GetDeviceInfo(), creation_context.context));
+ RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
+ code_, "main_function", compiler_options_, *creation_context.context, *creation_context.device,
+ &kernel_));
+ return PostCompileCheck(creation_context.device->info_, kernel_.info_);
+}
+
+absl::Status GPUOperation::CompileDeserialized(const CreationContext &creation_context)
+{
+ return creation_context.cache->GetOrCreateCLKernel(code_, "main_function", compiler_options_,
+ *creation_context.context,
+ *creation_context.device, &kernel_);
+}
+
+void GPUOperation::GetPossibleKernelWorkGroups(TuningType tuning_type,
+ const DeviceInfo &device_info,
+ const KernelInfo &kernel_info,
+ std::vector<int3> *work_groups) const
+{
+ GetPossibleWorkGroups(tuning_type, device_info, kernel_info, grid_size_, work_groups);
+}
+
+absl::Status GPUOperation::Tune(const TuningParameters &params)
+{
+ std::vector<int3> possible_work_groups;
+ GetPossibleKernelWorkGroups(params.tuning_type, *params.info, kernel_.info_,
+ &possible_work_groups);
+ if (possible_work_groups.empty())
+ {
+ return absl::NotFoundError("Cannot find work_group size to launch kernel");
+ }
+ if (possible_work_groups.size() == 1)
+ {
+ work_group_size_ = possible_work_groups[0];
+ work_groups_count_ =
+ GetWorkGroupsCount(grid_dimension_, grid_size_, work_group_size_, work_group_launch_order_);
+ return absl::OkStatus();
+ }
+ else
+ {
+ std::vector<int3> work_groups_count(possible_work_groups.size());
+ for (size_t i = 0; i < work_groups_count.size(); ++i)
+ {
+ work_groups_count[i] = GetWorkGroupsCount(grid_dimension_, grid_size_,
+ possible_work_groups[i], work_group_launch_order_);
+ }
+ RETURN_IF_ERROR(args_.Bind(kernel_.kernel()));
+ int best_work_group_index;
+ RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex(
+ kernel_, *params.info, work_groups_count, possible_work_groups, &best_work_group_index));
+ work_group_size_ = possible_work_groups[best_work_group_index];
+ work_groups_count_ =
+ GetWorkGroupsCount(grid_dimension_, grid_size_, work_group_size_, work_group_launch_order_);
+ return absl::OkStatus();
+ }
+}
+
+int3 GPUOperation::GetGridSize() const
+{
+ if (elementwise_ || tensor_to_grid_ == TensorToGrid::kWBToX_HDToY_SToZ)
+ {
+ const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
+ const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
+ const int grid_z = dst_[0]->Slices();
+ return int3(grid_x, grid_y, grid_z);
+ }
+ if (tensor_to_grid_ == TensorToGrid::kWBToX_HDToY_ZIs1)
+ {
+ const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
+ const int grid_y = dst_[0]->Height() * dst_[0]->Depth();
+ const int grid_z = 1;
+ return int3(grid_x, grid_y, grid_z);
+ }
+ if (tensor_to_grid_ == TensorToGrid::kWBToX_HToY_DToZ)
+ {
+ const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
+ const int grid_y = dst_[0]->Height();
+ const int grid_z = dst_[0]->Depth();
+ return int3(grid_x, grid_y, grid_z);
+ }
+ if (tensor_to_grid_ == TensorToGrid::kBToX_YIs1_ZIs1)
+ {
+ const int grid_x = dst_[0]->Batch();
+ const int grid_y = 1;
+ const int grid_z = 1;
+ return int3(grid_x, grid_y, grid_z);
+ }
+ return grid_size_;
+}
+
+void GPUOperation::AddUniquePostfix(const std::string &unique_postfix)
+{
+ for (uint32_t i = 0; i < src_tensors_names_.size(); ++i)
+ {
+ src_tensors_names_[i] += unique_postfix;
+ }
+ for (uint32_t i = 0; i < dst_tensors_names_.size(); ++i)
+ {
+ dst_tensors_names_[i] += unique_postfix;
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_GPU_OPERATION_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_GPU_OPERATION_H__
+
+#include <string>
+#include <vector>
+
+#include "TuningParameters.h"
+
+#include "open_cl/Arguments.h"
+#include "open_cl/Buffer.h"
+#include "open_cl/ClCommandQueue.h"
+#include "open_cl/ClContext.h"
+#include "open_cl/ClDevice.h"
+#include "open_cl/ClKernel.h"
+#include "open_cl/ClProgram.h"
+#include "open_cl/DataType.h"
+#include "open_cl/DeviceInfo.h"
+#include "open_cl/Precision.h"
+#include "open_cl/ProgramCache.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/TensorType.h"
+#include "open_cl/Types.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// kCustom: default value
+// GPUOperation::GetGridSize must be overridden
+// kWBToX_HDToY_SToZ:
+// grid_x = dst_[0]->Width() * dst_[0]->Batch();
+// grid_y = dst_[0]->Height() * dst_[0]->Depth();
+// grid_z = dst_[0]->Slices();
+// kWBToX_HDToY_ZIs1:
+// grid_x = dst_[0]->Width() * dst_[0]->Batch();
+// grid_y = dst_[0]->Height() * dst_[0]->Depth();
+// grid_z = 1;
+// kWBToX_HToY_DToZ:
+// grid_x = dst_[0]->Width() * dst_[0]->Batch();
+// grid_y = dst_[0]->Height();
+// grid_z = dst_[0]->Depth();
+// kBToX_YIs1_ZIs1:
+// grid_x = dst_[0]->Batch();
+// grid_y = 1;
+// grid_z = 1;
+enum class TensorToGrid
+{
+ kCustom,
+ kWBToX_HDToY_SToZ,
+ kWBToX_HDToY_ZIs1,
+ kWBToX_HToY_DToZ,
+ kBToX_YIs1_ZIs1
+};
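+
+// For example, CreatePooling() and CreateReshape() in this backend set
+// tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ, so GPUOperation derives
+// the grid from the destination tensor and GetGridSize() does not need to be
+// overridden.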
+
+struct CreationContext
+{
+ const CLDevice *device;
+ CLContext *context;
+ CLCommandQueue *queue;
+ ProgramCache *cache;
+
+ const DeviceInfo &GetDeviceInfo() const { return device->info_; }
+};
+
+struct OperationDef
+{
+ CalculationsPrecision precision;
+ std::vector<TensorDescriptor> src_tensors;
+ std::vector<TensorDescriptor> dst_tensors;
+
+ // returns FLOAT32 for F32 precision and FLOAT16 for F16 precision
+ DataType GetDataType() const;
+ // Primary means the first src tensor, because the first tensor usually
+ // defines the structure of the kernel and the types of all other resources
+ // (biases, etc.).
+ DataType GetPrimaryDataType() const;
+ TensorStorageType GetPrimaryStorageType() const;
+ bool IsBatchSupported() const;
+};
+
+// GPUOperation represents some implementation of a neural network operation
+// on the GPU. A GPUOperation can contain other GPU operations when the
+// elementwise_ flag is set; in that case this GPUOperation replaces a
+// sequence of operations Op + op0 + op1 + ...
+// Because of this ability, the usage scenario is:
+//   Create an instance of GPUOperation.
+//   Create all instances of GPUOperations that will (probably) be attached
+//   to it, and attach them with AddOperation(). Then call
+//   GPUOperation::Compile() on the combined operation. Don't call Compile()
+//   on an attached GPUOperation; it is useless (and may be an error).
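+//
+// Illustrative sketch (not taken from the original sources; "CreateSomeConv"
+// is a hypothetical factory used only for illustration):
+//
+//   GPUOperation conv = CreateSomeConv(definition);
+//   GPUOperation relu = CreateReLU(relu_definition, relu_attr);
+//   conv.AddOperation(&relu);        // link the elementwise ReLU into conv
+//   conv.Compile(creation_context);  // compile only the combined operation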
+class GPUOperation
+{
+public:
+ GPUOperation() = default;
+ explicit GPUOperation(const OperationDef &definition);
+ virtual ~GPUOperation() = default;
+ // Move only
+ GPUOperation(GPUOperation &&operation);
+ GPUOperation &operator=(GPUOperation &&operation);
+ GPUOperation(const GPUOperation &) = delete;
+ GPUOperation &operator=(const GPUOperation &) = delete;
+
+ absl::Status AddOperation(GPUOperation *operation);
+
+ void SetSrc(Tensor *ptr, int index = 0);
+ void SetDst(Tensor *ptr, int index = 0);
+
+ // should be called after changes of inputs/outputs.
+ absl::Status UpdateParams();
+
+ absl::Status AddToQueue(CLCommandQueue *queue)
+ {
+ RETURN_IF_ERROR(args_.Bind(kernel_.kernel()));
+ return queue->Dispatch(kernel_, work_groups_count_, work_group_size_);
+ }
+
+ virtual void GetPossibleKernelWorkGroups(TuningType tuning_type, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info,
+ std::vector<int3> *work_groups) const;
+
+ absl::Status Tune(const TuningParameters &params);
+
+ absl::Status AssembleCode(const DeviceInfo &device_info, CLContext *context);
+
+ absl::Status Compile(const CreationContext &creation_context);
+
+ absl::Status CompileDeserialized(const CreationContext &creation_context);
+
+ virtual absl::Status PostCompileCheck(const DeviceInfo &, const KernelInfo &)
+ {
+ return absl::OkStatus();
+ }
+
+ const OperationDef &GetDefinition() const { return definition_; }
+
+ void AddSrcTensor(const std::string &tensor_name, const TensorDescriptor &desc);
+ void AddSrcBuffer(const std::string &buffer_name, const BufferDescriptor &desc);
+ void AddDstTensor(const std::string &tensor_name, const TensorDescriptor &desc);
+
+ bool IsLinkable() const { return elementwise_ && linkable_; }
+
+ // for linking
+ void AddUniquePostfix(const std::string &unique_postfix);
+
+ Arguments args_;
+ std::string code_;
+ int3 work_group_size_ = int3(8, 4, 1);
+ std::vector<CompilerOptions> compiler_options_;
+ // not applicable to elementwise
+ TensorToGrid tensor_to_grid_ = TensorToGrid::kCustom;
+
+ bool elementwise_ = false;
+ // applicable only with elementwise_ = true;
+ bool linkable_ = true; // by default every elementwise is linkable
+ // applicable only with elementwise_ = true;
+ bool check_src_channels_size_ = false;
+
+protected:
+ virtual absl::Status BindArguments(ArgumentsBinder *) { return absl::OkStatus(); }
+ virtual int3 GetGridSize() const;
+
+ // Defines operation calculation precision and format of src/dst tensors.
+ OperationDef definition_;
+ std::vector<Tensor *> src_;
+ std::vector<Tensor *> dst_;
+ CLKernel kernel_;
+ int grid_dimension_ = 3; // can be 1, 2 or 3
+ int3 work_group_launch_order_ = int3(0, 1, 2);
+ int3 grid_size_ = int3(0, 0, 0);
+ std::vector<std::string> src_tensors_names_;
+ std::vector<std::string> dst_tensors_names_;
+
+private:
+ int3 work_groups_count_ = int3(0, 0, 0);
+ int linkable_count_ = 0;
+ std::string elementwise_code_; // temporary, used during op construction
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_GPU_OPERATION_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Pooling.h"
+
+#include <string>
+
+#include "Util.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+std::string GetAveragePoolingKernelCode(const OperationDef &op_def, bool stride_correction,
+ GPUOperation *op)
+{
+ auto src_desc = op_def.src_tensors[0];
+
+ src_desc.SetTextureAddressMode(TextureAddressMode::ZERO);
+
+ if (op_def.IsBatchSupported())
+ {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op->AddSrcTensor("src_tensor", src_desc);
+ auto dst_desc = op_def.dst_tensors[0];
+ if (op_def.IsBatchSupported())
+ {
+ dst_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op->AddDstTensor("dst_tensor", dst_desc);
+
+ std::map<Axis, std::string> axis_to_src_coord = {
+ {Axis::WIDTH, "x_c"}, {Axis::HEIGHT, "y_c"}, {Axis::DEPTH, "d_c"},
+ {Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
+ };
+
+ std::map<Axis, std::string> axis_to_dst_coord = {
+ {Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
+ {Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
+ };
+
+ std::vector<std::string> src_coords;
+ std::vector<std::string> dst_coords;
+ for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS})
+ {
+ if (op_def.dst_tensors[0].HasAxis(axis))
+ {
+ dst_coords.push_back(axis_to_dst_coord[axis]);
+ }
+ if (op_def.src_tensors[0].HasAxis(axis))
+ {
+ src_coords.push_back(axis_to_src_coord[axis]);
+ }
+ }
+ std::string src_coord = src_coords[0];
+ for (size_t i = 1; i < src_coords.size(); ++i)
+ {
+ src_coord += ", " + src_coords[i];
+ }
+ std::string dst_coord = dst_coords[0];
+ for (size_t i = 1; i < dst_coords.size(); ++i)
+ {
+ dst_coord += ", " + dst_coords[i];
+ }
+
+ const bool manual_clamp = op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER ||
+ op_def.src_tensors[0].storage_type == TensorStorageType::IMAGE_BUFFER;
+
+ std::string c = GetCommonDefines(op_def.precision);
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ c += " int X = get_global_id(0);\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " int linear_id_1 = get_global_id(1);\n";
+ c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n";
+ c += " int D = linear_id_1 % args.dst_tensor.Depth();\n";
+ }
+ else
+ {
+ c += " int Y = get_global_id(1);\n";
+ }
+ c += " int Z = get_global_id(2);\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+ "Z >= args.dst_tensor.Slices()) { \n";
+ c += " return; \n";
+ c += " } \n";
+ c += " float4 r = (float4)(0.0f);\n";
+ c += " float window_size = 0.0;\n";
+ if (stride_correction)
+ {
+ c += " int xs = " +
+ GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x", "args.padding_x") +
+ ";\n";
+ }
+ else
+ {
+ if (op_def.IsBatchSupported())
+ {
+ c += " int xs = X * args.stride_x + args.padding_x * "
+ "args.src_tensor.Batch();\n";
+ }
+ else
+ {
+ c += " int xs = X * args.stride_x + args.padding_x;\n";
+ }
+ }
+ c += " int ys = Y * args.stride_y + args.padding_y;\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " int ds = D * args.stride_z + args.padding_z;\n";
+ c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n";
+ c += " int d_c = ds + kz;\n";
+ c += " if (d_c < 0 || d_c >= args.src_tensor.Depth()) continue;\n";
+ }
+ c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
+ c += " int y_c = ys + ky;\n";
+ c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n";
+ c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
+ if (op_def.IsBatchSupported())
+ {
+ c += " int x_c = xs + kx * args.src_tensor.Batch();\n";
+ }
+ else
+ {
+ c += " int x_c = xs + kx;\n";
+ }
+ c += " bool outside = outside_y || x_c < 0 || x_c >= "
+ "args.src_tensor.Width();\n";
+ if (manual_clamp)
+ {
+ c += " r += !outside ? args.src_tensor.Read<float>(" + src_coord +
+ ") : "
+ "(float4)(0.0f);\n";
+ }
+ else
+ {
+ c += " r += args.src_tensor.Read<float>(" + src_coord + ");\n";
+ }
+ c += " window_size += !outside ? 1.0 : 0.0;\n";
+ c += " }\n";
+ c += " }\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " } // Depth\n";
+ }
+ // If window_size == 0, the window covered nothing. This situation is a sign
+ // of an incorrectly constructed operation; NaNs are expected as output.
+ c += " FLT4 result = TO_FLT4(r / window_size);\n";
+ c += " args.dst_tensor.Write(result, " + dst_coord + ");\n";
+ c += "}\n";
+
+ return c;
+}
+
+std::string GetMaxPoolingKernelCode(const OperationDef &op_def, bool stride_correction,
+ bool output_indices, GPUOperation *op)
+{
+ auto src_desc = op_def.src_tensors[0];
+ if (op_def.IsBatchSupported())
+ {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op->AddSrcTensor("src_tensor", src_desc);
+ auto dst_desc = op_def.dst_tensors[0];
+ if (op_def.IsBatchSupported())
+ {
+ dst_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op->AddDstTensor("dst_tensor", dst_desc);
+ if (output_indices)
+ {
+ auto dst_ind_desc = op_def.dst_tensors[1];
+ if (op_def.IsBatchSupported())
+ {
+ dst_ind_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op->AddDstTensor("dst_indices", dst_ind_desc);
+ }
+
+ std::map<Axis, std::string> axis_to_src_coord = {
+ {Axis::WIDTH, "x_c"}, {Axis::HEIGHT, "y_c"}, {Axis::DEPTH, "d_c"},
+ {Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
+ };
+
+ std::map<Axis, std::string> axis_to_dst_coord = {
+ {Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
+ {Axis::CHANNELS, "Z"}, {Axis::BATCH, "B"},
+ };
+
+ std::vector<std::string> src_coords;
+ std::vector<std::string> dst_coords;
+ for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS})
+ {
+ if (op_def.dst_tensors[0].HasAxis(axis))
+ {
+ dst_coords.push_back(axis_to_dst_coord[axis]);
+ }
+ if (op_def.src_tensors[0].HasAxis(axis))
+ {
+ src_coords.push_back(axis_to_src_coord[axis]);
+ }
+ }
+ std::string src_coord = src_coords[0];
+ for (size_t i = 1; i < src_coords.size(); ++i)
+ {
+ src_coord += ", " + src_coords[i];
+ }
+ std::string dst_coord = dst_coords[0];
+ for (size_t i = 1; i < dst_coords.size(); ++i)
+ {
+ dst_coord += ", " + dst_coords[i];
+ }
+
+ std::string c = GetCommonDefines(op_def.precision);
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ c += " int X = get_global_id(0);\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " int linear_id_1 = get_global_id(1);\n";
+ c += " int Y = linear_id_1 / args.dst_tensor.Depth();\n";
+ c += " int D = linear_id_1 % args.dst_tensor.Depth();\n";
+ }
+ else
+ {
+ c += " int Y = get_global_id(1);\n";
+ }
+ c += " int Z = get_global_id(2);\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+ "Z >= args.dst_tensor.Slices()) { \n";
+ c += " return; \n";
+ c += " } \n";
+ c += " FLT4 maximum = (FLT4)(-10000.0f);\n";
+ if (output_indices)
+ {
+ c += " FLT4 indexes = (FLT4)(0.0f);\n";
+ }
+ if (stride_correction)
+ {
+ c += " int xs = " +
+ GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x", "args.padding_x") +
+ ";\n";
+ }
+ else
+ {
+ if (op_def.IsBatchSupported())
+ {
+ c += " int xs = X * args.stride_x + args.padding_x * "
+ "args.src_tensor.Batch();\n";
+ }
+ else
+ {
+ c += " int xs = X * args.stride_x + args.padding_x;\n";
+ }
+ }
+ c += " int ys = Y * args.stride_y + args.padding_y;\n";
+ c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
+ c += " int y_c = ys + ky;\n";
+ c += " if (y_c < 0 || y_c >= args.src_tensor.Height()) continue;\n";
+ c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
+ if (op_def.IsBatchSupported())
+ {
+ c += " int x_c = xs + kx * args.src_tensor.Batch();\n";
+ }
+ else
+ {
+ c += " int x_c = xs + kx;\n";
+ }
+ c += " if (x_c < 0 || x_c >= args.src_tensor.Width()) continue;\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " int ds = D * args.stride_z + args.padding_z;\n";
+ c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n";
+ c += " int d_c = ds + kz;\n";
+ c += " if (d_c < 0 || d_c >= args.src_tensor.Depth()) continue;\n";
+ }
+ c += " FLT4 src = args.src_tensor.Read(" + src_coord + ");\n";
+ if (output_indices)
+ {
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " FLT index_counter = (FLT)((ky * args.kernel_size_x + kx) * "
+ "args.kernel_size_z + kz) + (FLT)(0.1f);\n";
+ }
+ else
+ {
+ c += " FLT index_counter = (FLT)(ky * args.kernel_size_x + kx) + "
+ "(FLT)(0.1f);\n";
+ }
+ c += " if (src.x > maximum.x) {\n";
+ c += " indexes.x = index_counter;\n";
+ c += " maximum.x = src.x;\n";
+ c += " }\n";
+ c += " if (src.y > maximum.y) {\n";
+ c += " indexes.y = index_counter;\n";
+ c += " maximum.y = src.y;\n";
+ c += " }\n";
+ c += " if (src.z > maximum.z) {\n";
+ c += " indexes.z = index_counter;\n";
+ c += " maximum.z = src.z;\n";
+ c += " }\n";
+ c += " if (src.w > maximum.w) {\n";
+ c += " indexes.w = index_counter;\n";
+ c += " maximum.w = src.w;\n";
+ c += " }\n";
+ }
+ else
+ {
+ c += " maximum = max(src, maximum);\n";
+ }
+ if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH))
+ {
+ c += " } // Depth\n";
+ }
+ c += " }\n";
+ c += " }\n";
+ c += " args.dst_tensor.Write(maximum, " + dst_coord + ");\n";
+ if (output_indices)
+ {
+ c += " args.dst_indices.Write(indexes, " + dst_coord + ");\n";
+ }
+ c += "}\n";
+
+ return c;
+}
+} // namespace
+
+GPUOperation CreatePooling(const OperationDef &definition, const Pooling2DAttributes &attr)
+{
+ GPUOperation op(definition);
+ op.args_.AddInt("kernel_size_x", attr.kernel.w);
+ op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+ op.args_.AddInt("stride_x", attr.strides.w);
+ op.args_.AddInt("kernel_size_y", attr.kernel.h);
+ op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+ op.args_.AddInt("stride_y", attr.strides.h);
+
+ const bool stride_correction = definition.IsBatchSupported() && attr.strides.w != 1;
+ if (attr.type == PoolingType::AVERAGE)
+ {
+ op.code_ = GetAveragePoolingKernelCode(definition, stride_correction, &op);
+ }
+ else if (attr.type == PoolingType::MAX)
+ {
+ op.code_ = GetMaxPoolingKernelCode(definition, stride_correction, attr.output_indices, &op);
+ }
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+ return op;
+}
+
+GPUOperation CreatePooling(const OperationDef &definition, const Pooling3DAttributes &attr)
+{
+ GPUOperation op(definition);
+ op.args_.AddInt("kernel_size_x", attr.kernel.w);
+ op.args_.AddInt("padding_x", -attr.padding.prepended.w);
+ op.args_.AddInt("stride_x", attr.strides.w);
+ op.args_.AddInt("kernel_size_y", attr.kernel.h);
+ op.args_.AddInt("padding_y", -attr.padding.prepended.h);
+ op.args_.AddInt("stride_y", attr.strides.h);
+ op.args_.AddInt("kernel_size_z", attr.kernel.d);
+ op.args_.AddInt("padding_z", -attr.padding.prepended.d);
+ op.args_.AddInt("stride_z", attr.strides.d);
+ const bool stride_correction = definition.IsBatchSupported() && attr.strides.w != 1;
+ if (attr.type == PoolingType::AVERAGE)
+ {
+ op.code_ = GetAveragePoolingKernelCode(definition, stride_correction, &op);
+ }
+ else if (attr.type == PoolingType::MAX)
+ {
+ op.code_ = GetMaxPoolingKernelCode(definition, stride_correction, attr.output_indices, &op);
+ }
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+ return op;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_POOLING_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_POOLING_H__
+
+#include "GpuOperation.h"
+
+#include "open_cl/Operations.h"
+#include "open_cl/Precision.h"
+#include "open_cl/ClKernel.h"
+#include "open_cl/Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+GPUOperation CreatePooling(const OperationDef &definition, const Pooling2DAttributes &attr);
+
+GPUOperation CreatePooling(const OperationDef &definition, const Pooling3DAttributes &attr);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_POOLING_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Relu.h"
+
+#include <string>
+#include "Util.h"
+#include "GpuOperation.h"
+#include "absl/strings/str_cat.h"
+#include "open_cl/Precision.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+GPUOperation CreateReLU(const OperationDef &definition, const ReLUAttributes &attr)
+{
+ GPUOperation op(definition);
+ op.elementwise_ = true;
+
+ std::string min_func;
+ if (attr.alpha != 0.0f)
+ {
+ min_func = "min(in_out_value * args.alpha, (FLT)(0.0f))";
+ if (definition.precision == CalculationsPrecision::F32)
+ {
+ op.args_.AddFloat("alpha", attr.alpha);
+ }
+ else
+ {
+#ifdef FIXME_PORTING_HALF_REQIRED
+ op.args_.AddHalf("alpha", half(attr.alpha));
+#endif
+ }
+ }
+ else
+ {
+ min_func = "(FLT)(0.0f)";
+ }
+ if (attr.clip != 0.0f)
+ {
+ if (definition.precision == CalculationsPrecision::F32)
+ {
+ op.args_.AddFloat("clip", attr.clip);
+ }
+ else
+ {
+#ifdef FIXME_PORTING_HALF_REQIRED
+ op.args_.AddHalf("clip", half(attr.clip));
+#endif
+ }
+ op.code_ = absl::StrCat("in_out_value = clamp(in_out_value, ", min_func, ", args.clip);");
+ }
+ else
+ {
+ op.code_ = absl::StrCat("in_out_value = max(in_out_value, ", min_func, ");");
+ }
+ return op;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_RELU_H__
+#define __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_RELU_H__
+
+#include "open_cl/ClKernel.h"
+#include "GpuOperation.h"
+#include "open_cl/Precision.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Types.h"
+#include "open_cl/Operations.h"
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+GPUOperation CreateReLU(const OperationDef &definition, const ReLUAttributes &attr);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_RELU_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Reshape.h"
+
+#include <string>
+
+#include "Util.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+std::string GetReshapeCode(const OperationDef &op_def)
+{
+ std::string c = GetCommonDefines(op_def.precision);
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::BATCH))
+ {
+ c += " int linear_id = get_global_id(0);\n";
+ c += " int X = linear_id / args.dst_tensor.Batch();\n";
+ c += " int B = linear_id % args.dst_tensor.Batch();\n";
+ c += " args.dst_tensor.SetBatchRef(B);\n";
+ }
+ else
+ {
+ c += " int X = get_global_id(0);\n";
+ }
+ c += " int Y = get_global_id(1);\n";
+ c += " int Z = get_global_id(2);\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+ "Z >= args.dst_tensor.Slices()) { \n";
+ c += " return; \n";
+ c += " } \n";
+ c += " FLT temps[4];\n";
+ c += " temps[0] = (FLT)(0.0f);\n";
+ c += " temps[1] = (FLT)(0.0f);\n";
+ c += " temps[2] = (FLT)(0.0f);\n";
+ c += " temps[3] = (FLT)(0.0f);\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::BATCH))
+ {
+ c += " int base = B;\n";
+ }
+ else
+ {
+ c += " int base = 0;\n";
+ }
+ c += " base = ((base * args.dst_tensor.Height() + Y) * "
+ "args.dst_tensor.Width() + X) * args.dst_tensor.Channels() + Z * 4;\n";
+ c += " for (int i = 0; i < 4; ++i) {\n";
+ c += " int dst_channel = Z * 4 + i;\n";
+ c += " if (dst_channel < args.dst_tensor.Channels()) {\n";
+ c += " int p = base + i;\n";
+ c += " int src_c = p % args.src_tensor.Channels();\n";
+ c += " p = p / args.src_tensor.Channels();\n";
+ c += " int src_x = p % args.src_tensor.Width();\n";
+ c += " p = p / args.src_tensor.Width();\n";
+ c += " int src_y = p % args.src_tensor.Height();\n";
+ if (op_def.src_tensors[0].HasAxis(Axis::BATCH))
+ {
+ c += " int src_b = p / args.src_tensor.Height();\n";
+ c += " args.src_tensor.SetBatchRef(src_b);\n";
+ }
+ c += " int src_z = src_c / 4;\n";
+ c += " int src_sub_ch = src_c % 4;\n";
+ c += " FLT4 t = args.src_tensor.Read(src_x, src_y, src_z);\n";
+ c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
+ c += " temps[i] = t_ar[src_sub_ch];\n";
+ c += " }\n";
+ c += " }\n";
+ c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n";
+ c += " args.dst_tensor.Write(result, X, Y, Z);\n";
+ c += "}\n";
+ return c;
+}
+
+} // namespace
+
+GPUOperation CreateReshape(const OperationDef &definition)
+{
+ GPUOperation op(definition);
+ op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
+ op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+ op.code_ = GetReshapeCode(definition);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+ return op;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_RESHAPE_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_RESHAPE_H__
+
+#include "GpuOperation.h"
+
+#include "open_cl/Operations.h"
+#include "open_cl/Precision.h"
+#include "open_cl/ClKernel.h"
+#include "open_cl/Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+GPUOperation CreateReshape(const OperationDef &definition);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_RESHAPE_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Reshapex4.h"
+
+#include <string>
+
+#include "Util.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+std::string GetReshapeCode(const OperationDef &op_def)
+{
+ std::string c = GetCommonDefines(op_def.precision);
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::BATCH))
+ {
+ c += " int linear_id = get_global_id(0);\n";
+ c += " int X = linear_id / args.dst_tensor.Batch();\n";
+ c += " int B = linear_id % args.dst_tensor.Batch();\n";
+ c += " args.dst_tensor.SetBatchRef(B);\n";
+ }
+ else
+ {
+ c += " int X = get_global_id(0);\n";
+ }
+ c += " int Y = get_global_id(1);\n";
+ c += " int Z = get_global_id(2);\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
+ "Z >= args.dst_tensor.Slices()) { \n";
+ c += " return; \n";
+ c += " } \n";
+ if (op_def.dst_tensors[0].HasAxis(Axis::BATCH))
+ {
+ c += " int dst_bhwc4 = B;\n";
+ }
+ else
+ {
+ c += " int dst_bhwc4 = 0;\n";
+ }
+ c += " dst_bhwc4 = ((dst_bhwc4 * args.dst_tensor.Height() + Y) * "
+ "args.dst_tensor.Width() + X) * args.dst_tensor.Slices() + Z;\n";
+ c += " int src_z = dst_bhwc4 % args.src_tensor.Slices();\n";
+ c += " dst_bhwc4 = dst_bhwc4 / args.src_tensor.Slices();\n";
+ c += " int src_x = dst_bhwc4 % args.src_tensor.Width();\n";
+ c += " dst_bhwc4 = dst_bhwc4 / args.src_tensor.Width();\n";
+ c += " int src_y = dst_bhwc4 % args.src_tensor.Height();\n";
+ if (op_def.src_tensors[0].HasAxis(Axis::BATCH))
+ {
+ c += " int src_b = dst_bhwc4 / args.src_tensor.Height();\n";
+ c += " args.src_tensor.SetBatchRef(src_b);\n";
+ }
+ c += " FLT4 result = args.src_tensor.Read(src_x, src_y, src_z);\n";
+ c += " args.dst_tensor.Write(result, X, Y, Z);\n";
+ c += "}\n";
+ return c;
+}
+
+} // namespace
+
+GPUOperation CreateReshapex4(const OperationDef &definition)
+{
+ GPUOperation op(definition);
+ op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
+ op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
+ op.code_ = GetReshapeCode(definition);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
+ return op;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_RESHAPEX4_H__
+#define __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_RESHAPEX4_H__
+
+#include "GpuOperation.h"
+
+#include "open_cl/Operations.h"
+#include "open_cl/Precision.h"
+#include "open_cl/ClKernel.h"
+#include "open_cl/Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// More optimized, but requires src_channels % 4 == 0 and dst_channels % 4 == 0
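+// (Illustrative example, not from the sources: reshaping 1x8x8x32 to
+// 1x16x16x8 qualifies, since both channel counts are multiples of 4.)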
+GPUOperation CreateReshapex4(const OperationDef &definition);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_RESHAPEX4_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Softmax.h"
+
+#include <string>
+
+#include "Util.h"
+#include "WorkGroupPicking.h"
+#include "GpuOperation.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+namespace
+{
+std::string GetSoftmaxKernelCode(const OperationDef &op_def)
+{
+ std::string c = GetCommonDefines(op_def.precision);
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ c += " int X = get_global_id(0);\n";
+ c += " int Y = get_global_id(1);\n";
+ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
+ "return; \n";
+ c += " float sum = 0.0f;\n";
+ c += " for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n";
+ c += " float4 t = args.src_tensor.Read<float>(X, Y, d);\n";
+ c += " sum += exp(t.x);\n";
+ c += " if (d * 4 + 1 < args.dst_tensor.Channels()) sum += exp(t.y);\n";
+ c += " if (d * 4 + 2 < args.dst_tensor.Channels()) sum += exp(t.z);\n";
+ c += " if (d * 4 + 3 < args.dst_tensor.Channels()) sum += exp(t.w);\n";
+ c += " }\n";
+ c += " for (int d = 0; d < args.dst_tensor.Slices(); ++d) {\n";
+ c += " float4 t = args.src_tensor.Read<float>(X, Y, d);\n";
+ c += " t = exp(t) / sum;\n";
+ c += " FLT4 result = TO_FLT4(t);\n";
+ c += " args.dst_tensor.Write(result, X, Y, d);\n";
+ c += " }\n";
+ c += "}\n";
+ return c;
+}
+} // namespace
+
+GPUOperation CreateSoftmax(const OperationDef &definition)
+{
+ GPUOperation op(definition);
+ auto src_desc = definition.src_tensors[0];
+ if (definition.IsBatchSupported())
+ {
+ src_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op.AddSrcTensor("src_tensor", src_desc);
+ auto dst_desc = definition.dst_tensors[0];
+ if (definition.IsBatchSupported())
+ {
+ dst_desc.SetStateVar("BatchedWidth", "true");
+ }
+ op.AddDstTensor("dst_tensor", dst_desc);
+ op.code_ = GetSoftmaxKernelCode(definition);
+ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_ZIs1;
+ return op;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_SOFTMAX_H__
+#define __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_SOFTMAX_H__
+
+#include "open_cl/ClKernel.h"
+#include "GpuOperation.h"
+#include "open_cl/Precision.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Types.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+GPUOperation CreateSoftmax(const OperationDef &definition);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_SOFTMAX_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Softmax1x1.h"
+
+#include <string>
+
+#include "Util.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+Softmax1x1::Softmax1x1(const OperationDef &definition) : GPUOperation(definition)
+{
+ work_group_size_ = int3(32, 1, 1);
+ code_ = GetSoftmaxKernelCode(definition_);
+}
+
+Softmax1x1::Softmax1x1(Softmax1x1 &&kernel) : GPUOperation(std::move(kernel)) {}
+
+Softmax1x1 &Softmax1x1::operator=(Softmax1x1 &&kernel)
+{
+ if (this != &kernel)
+ {
+ GPUOperation::operator=(std::move(kernel));
+ }
+ return *this;
+}
+
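+// Kernel sketch: 32 work items (work_group_size_ = (32, 1, 1)) each accumulate
+// a partial exp() sum over a strided subset of slices, the 32 partial sums are
+// reduced through the __local float buffer, and each work item then writes its
+// normalized slices back.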
+std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef &op_def)
+{
+ AddSrcTensor("src_tensor", op_def.src_tensors[0]);
+ AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
+ args_.AddFloat("mask_x");
+ args_.AddFloat("mask_y");
+ args_.AddFloat("mask_z");
+ args_.AddFloat("mask_w");
+ args_.AddInt("slices_x32");
+
+ std::string c = GetCommonDefines(op_def.precision);
+ c += "__kernel void main_function(\n";
+ c += "$0) {\n";
+ if (op_def.IsBatchSupported())
+ {
+ c += " int batch_id = get_global_id(1);\n";
+ c += " if (batch_id >= args.dst_tensor.Batch()) return;\n";
+ c += " args.dst_tensor.SetBatchRef(batch_id);\n";
+ c += " args.src_tensor.SetBatchRef(batch_id);\n";
+ }
+ c += " float4 mask = (float4)(args.mask_x, args.mask_y, args.mask_z, "
+ "args.mask_w);\n";
+ c += " int offset = 0;\n";
+ c += " float sum = 0.0f;\n";
+ c += " int s = 0;\n";
+ c += " int tid = get_local_id(0);\n";
+ c += " do {\n";
+ c += " int z = offset + tid;\n";
+ c += " if (z < args.dst_tensor.Slices()) {\n";
+ c += " float4 mask_temp = z == args.dst_tensor.Slices() - 1 ? mask : "
+ "(float4)(1.0f);\n";
+ c += " float4 src = args.src_tensor.Read<float>(0, 0, z);\n";
+ c += " sum += dot(mask_temp, exp(src));\n";
+ c += " offset += 32;\n";
+ c += " }\n";
+ c += " s++;\n";
+ c += " } while (s < args.slices_x32);\n";
+ c += "\n";
+ c += " __local float4 tmp[8];\n";
+ c += " __local float* tmpx1 = (__local float*)tmp;\n";
+ c += " tmpx1[tid] = sum;\n";
+ c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
+ c += " if (tid == 0) {\n";
+ c += " sum = dot((float4)(1.0f), tmp[0]);\n";
+ c += " sum += dot((float4)(1.0f), tmp[1]);\n";
+ c += " sum += dot((float4)(1.0f), tmp[2]);\n";
+ c += " sum += dot((float4)(1.0f), tmp[3]);\n";
+ c += " sum += dot((float4)(1.0f), tmp[4]);\n";
+ c += " sum += dot((float4)(1.0f), tmp[5]);\n";
+ c += " sum += dot((float4)(1.0f), tmp[6]);\n";
+ c += " sum += dot((float4)(1.0f), tmp[7]);\n";
+ c += " tmpx1[0] = 1.0f / sum;\n";
+ c += " }\n";
+ c += " barrier(CLK_LOCAL_MEM_FENCE);\n";
+ c += " sum = tmpx1[0];\n";
+ c += "\n";
+ c += " offset = 0;\n";
+ c += " s = 0;\n";
+ c += " do {\n";
+ c += " int z = offset + tid;\n";
+ c += " if (z < args.dst_tensor.Slices()) {\n";
+ c += " FLT4 res = TO_FLT4(exp(args.src_tensor.Read<float>(0, 0, "
+ "z))*sum);\n";
+ c += " args.dst_tensor.Write(res, 0, 0, z);\n";
+ c += " offset += 32;\n";
+ c += " }\n";
+ c += " s++;\n";
+ c += " } while (s < args.slices_x32);\n";
+ c += "}\n";
+ return c;
+}
+
+absl::Status Softmax1x1::BindArguments(ArgumentsBinder *args)
+{
+ float4 mask = GetMaskForLastPlane(src_[0]->Channels());
+ RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
+ RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
+ RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
+ RETURN_IF_ERROR(args->SetFloat("mask_w", mask.w));
+ RETURN_IF_ERROR(args->SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32)));
+ return absl::OkStatus();
+}
+
+int3 Softmax1x1::GetGridSize() const { return int3(32, dst_[0]->Batch(), 1); }
+
+Softmax1x1 CreateSoftmax1x1(const OperationDef &definition) { return Softmax1x1(definition); }
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_SOFTMAX1X1_H__
+#define __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_SOFTMAX1X1_H__
+
+#include "GpuOperation.h"
+
+#include "open_cl/Precision.h"
+#include "open_cl/ClKernel.h"
+#include "open_cl/Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+class Softmax1x1 : public GPUOperation
+{
+public:
+ Softmax1x1() = default;
+ explicit Softmax1x1(const OperationDef &definition);
+
+ absl::Status BindArguments(ArgumentsBinder *args) override;
+ int3 GetGridSize() const override;
+
+ // Move only
+ Softmax1x1(Softmax1x1 &&kernel);
+ Softmax1x1 &operator=(Softmax1x1 &&kernel);
+ Softmax1x1(const Softmax1x1 &) = delete;
+ Softmax1x1 &operator=(const Softmax1x1 &) = delete;
+
+ friend Softmax1x1 CreateSoftmax1x1(const OperationDef &definition);
+
+private:
+ std::string GetSoftmaxKernelCode(const OperationDef &op_def);
+};
+
+Softmax1x1 CreateSoftmax1x1(const OperationDef &definition);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPEN_CL_KERNELS_SOFTMAX1X1_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_KERNELS_TUNING_PARAMETERS_H__
+#define __ONERT_BACKEND_GPU_CL_KERNELS_TUNING_PARAMETERS_H__
+
+#include "open_cl/ClCommandQueue.h"
+#include "open_cl/DeviceInfo.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+enum class TuningType
+{
+ EXHAUSTIVE,
+ FAST
+};
+
+struct TuningParameters
+{
+ ProfilingCommandQueue *queue;
+ const DeviceInfo *info;
+ TuningType tuning_type = TuningType::EXHAUSTIVE;
+};
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_KERNELS_TUNING_PARAMETERS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Util.h"
+
+#include <cfloat>
+#include <cmath>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/substitute.h"
+#include "open_cl/Precision.h"
+#include "open_cl/DataType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
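+// Emits the precision-dependent type aliases shared by the generated kernels.
+// Note that F32_F16 stores tensor values as half (FLT4 == half4) but
+// accumulates in float (ACCUM_FLT4 == float4).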
+std::string GetCommonDefines(CalculationsPrecision precision)
+{
+ std::string result;
+
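+ // F32_F16 stores tensor data in half precision (the FLT* types are half)
+ // while accumulating in float (ACCUM_FLT4 / TO_ACCUM_* stay float-based).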
+ switch (precision)
+ {
+ case CalculationsPrecision::F32:
+ result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
+ result += "#define ACCUM_FLT4 float4\n";
+ result += "#define FLT float\n";
+ result += "#define FLT2 float2\n";
+ result += "#define FLT3 float3\n";
+ result += "#define FLT4 float4\n";
+ result += "#define TO_FLT4 convert_float4\n";
+ result += "#define TO_ACCUM_TYPE convert_float4\n";
+ result += "#define TO_ACCUM_FLT convert_float\n";
+ break;
+ case CalculationsPrecision::F16:
+ result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
+ result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
+ result += "#define ACCUM_FLT4 half4\n";
+ result += "#define FLT half\n";
+ result += "#define FLT2 half2\n";
+ result += "#define FLT3 half3\n";
+ result += "#define FLT4 half4\n";
+ result += "#define TO_FLT4 convert_half4\n";
+ result += "#define TO_ACCUM_TYPE convert_half4\n";
+ result += "#define TO_ACCUM_FLT convert_half\n";
+ break;
+ case CalculationsPrecision::F32_F16:
+ result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
+ result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
+ result += "#define ACCUM_FLT4 float4\n";
+ result += "#define FLT half\n";
+ result += "#define FLT2 half2\n";
+ result += "#define FLT3 half3\n";
+ result += "#define FLT4 half4\n";
+ result += "#define TO_FLT4 convert_half4\n";
+ result += "#define TO_ACCUM_TYPE convert_float4\n";
+ result += "#define TO_ACCUM_FLT convert_float\n";
+ break;
+ }
+ return result;
+}
+
+std::string GetXStrideCorrectedV2(const std::string &src_x, const std::string &batch_size,
+ const std::string &stride_x, const std::string &padding_x)
+{
+ // Equivalent to:
+ //   int p0 = src_x / batch_size;
+ //   int b0 = src_x % batch_size;
+ //   return (p0 * stride_x + padding_x) * batch_size + b0;
+ return absl::Substitute("(((($0) / $1) * $2 + $3) * $1 + ($0) % $1)", src_x, batch_size, stride_x,
+ padding_x);
+}
+
+float4 GetMaskForLastPlane(int channels)
+{
+ float4 mask = float4(0.0f);
+ const int remainder = channels % 4 == 0 ? 4 : channels % 4;
+ for (int i = 0; i < remainder; ++i)
+ {
+ mask[i] = 1.0f;
+ }
+ return mask;
+}
+
+int3 GetFirstSuitableWorkGroup(const std::vector<int3> &wgs, int max_wg_size)
+{
+ for (const auto &wg : wgs)
+ {
+ const int wg_size = wg.x * wg.y * wg.z;
+ if (wg_size <= max_wg_size)
+ {
+ return wg;
+ }
+ }
+ return {1, 1, 1};
+}
+
+int GetRecommendedBlockSizeForConv(const DeviceInfo &device_info, CalculationsPrecision precision,
+ int task_size)
+{
+ const float task_size_per_cu = task_size / static_cast<float>(device_info.compute_units_count);
+ int block_size = 1;
+ float threshold_1 = FLT_MAX;
+ float threshold_2 = FLT_MAX;
+ float threshold_4 = FLT_MAX;
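+ // The thresholds bound task_size_per_cu for block sizes 1, 2 and 4; a
+ // threshold left at FLT_MAX means the search never advances past that block
+ // size for the given device generation.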
+ if (!device_info.IsMali())
+ {
+ return 1;
+ }
+ MaliInfo mali_info = device_info.mali_info;
+ switch (precision)
+ {
+ case CalculationsPrecision::F16:
+ if (mali_info.IsBifrostGen1())
+ {
+ threshold_1 = 256.0f;
+ threshold_2 = 256.0f * 4.0f;
+ threshold_4 = 256.0f * 8.0f;
+ }
+ else if (mali_info.IsBifrostGen2())
+ {
+ threshold_1 = 256.0f * 2.0f;
+ threshold_2 = 256.0f * 8.0f;
+ threshold_4 = 256.0f * 16.0f;
+ }
+ else if (mali_info.IsBifrostGen3() || mali_info.IsValhall())
+ {
+ threshold_1 = 256.0f;
+ threshold_2 = 256.0f * 6.0f;
+ threshold_4 = 256.0f * 16.0f;
+ }
+ else if (mali_info.IsMidgard())
+ {
+ threshold_1 = 256.0f * 4.0f;
+ threshold_2 = 256.0f * 16.0f;
+ }
+ break;
+ case CalculationsPrecision::F32_F16:
+ if (mali_info.IsBifrostGen1())
+ {
+ threshold_1 = 256.0f;
+ threshold_2 = 256.0f * 3.0f;
+ threshold_4 = 256.0f * 32.0f;
+ }
+ else if (mali_info.IsBifrostGen2())
+ {
+ threshold_1 = 256.0f * 2.0f;
+ threshold_2 = 256.0f * 8.0f;
+ }
+ else if (mali_info.IsBifrostGen3() || mali_info.IsValhall())
+ {
+ threshold_1 = 256.0f;
+ threshold_2 = 256.0f * 8.0f;
+ }
+ else if (mali_info.IsMidgard())
+ {
+ threshold_1 = 256.0f * 4.0f;
+ }
+ break;
+ case CalculationsPrecision::F32:
+ if (mali_info.IsBifrostGen1())
+ {
+ threshold_1 = 256.0f;
+ threshold_2 = 256.0f * 4.0f;
+ }
+ else if (mali_info.IsBifrostGen2())
+ {
+ threshold_1 = 128.0f;
+ threshold_2 = 256.0f * 4.0f;
+ }
+ else if (mali_info.IsBifrostGen3() || mali_info.IsValhall())
+ {
+ threshold_1 = 256.0f;
+ threshold_2 = 256.0f * 12.0f;
+ }
+ else if (mali_info.IsMidgard())
+ {
+ threshold_1 = 256.0f * 16.0f;
+ }
+ break;
+ }
+ if (task_size_per_cu <= threshold_1)
+ {
+ block_size = 1;
+ }
+ else if (task_size_per_cu <= threshold_2)
+ {
+ block_size = 2;
+ }
+ else if (task_size_per_cu <= threshold_4)
+ {
+ block_size = 4;
+ }
+ else
+ {
+ block_size = 8;
+ }
+ return block_size;
+}
+
+int3 GetWorkGroupsCount(const int3 &grid_size, const int3 &work_group_size)
+{
+ int3 work_groups_count;
+ work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
+ work_groups_count.y = DivideRoundUp(grid_size.y, work_group_size.y);
+ work_groups_count.z = DivideRoundUp(grid_size.z, work_group_size.z);
+ return work_groups_count;
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_UTIL_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_UTIL_H__
+
+#include <string>
+#include <vector>
+
+#include "open_cl/DeviceInfo.h"
+#include "open_cl/Precision.h"
+#include "open_cl/DataType.h"
+#include "open_cl/Shape.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/Types.h"
+#include "open_cl/Util.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+std::string GetCommonDefines(CalculationsPrecision precision);
+
+// Calculates correct X coordinate when stride != 1 and batch != 1 for layouts
+// with B after W (for example HWBC4) and WB stored in one axis of GPU
+// resources.
+std::string GetXStrideCorrected(const std::string &src_x, const std::string &batch_size,
+ const std::string &stride_x, const std::string &padding_x);
+
+// Calculates correct X coordinate when stride != 1 and batch != 1 for layouts
+// with B after W (for example HWBC4) and WB stored in one axis of GPU
+// resources.
+std::string GetXStrideCorrectedV2(const std::string &src_x, const std::string &batch_size,
+ const std::string &stride_x, const std::string &padding_x);
+
+// Returns a float4 mask for the last plane (batch of 4 channels).
+// Assumes the plane size is 4. For example, with 7 channels the data is
+// aligned to 8 channels, the 8th channel stays empty, and the last plane
+// (batch of 4 channels) gets the mask (1, 1, 1, 0).
+float4 GetMaskForLastPlane(int channels);
+
+// Returns the first work group from wgs whose size is not bigger than max_wg_size.
+// If there is no suitable group among wgs, returns {1, 1, 1}.
+int3 GetFirstSuitableWorkGroup(const std::vector<int3> &wgs, int max_wg_size);
+
+// task_size is the number of FLT4 elements to process.
+int GetRecommendedBlockSizeForConv(const DeviceInfo &device_info, CalculationsPrecision precision,
+ int task_size);
+
+int3 GetWorkGroupsCount(const int3 &grid_size, const int3 &work_group_size);
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_UTIL_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WorkGroupPicking.h"
+
+#include <algorithm>
+#include <limits>
+#include <set>
+#include <vector>
+
+#include "open_cl/Util.h"
+#include "open_cl/Types.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+namespace
+{
+
+std::vector<int2> Get2DWorkgroupsEqualTo128()
+{
+ return {{128, 1}, {64, 2}, {32, 4}, {16, 8}, {8, 16}, {4, 32}, {2, 64}, {1, 128}};
+}
+
+std::vector<int3> GenerateWorkGroupSizesXYMultipleOf(int multiplier, int3 grid,
+ const KernelInfo &kernel_info,
+ const DeviceInfo &device_info,
+ WorkGroupSizeAlignment z_alignment)
+{
+ std::vector<int3> work_groups;
+ work_groups.reserve(32);
+
+ std::vector<int> possible_z_sizes = GetPossibleSizes(grid.z, z_alignment);
+
+ for (int x = 1; x <= kernel_info.max_work_group_size; x *= 2)
+ {
+ for (int y = 1; y <= kernel_info.max_work_group_size; y *= 2)
+ {
+ int work_group_size_xy = x * y;
+ if (work_group_size_xy % multiplier != 0 ||
+ work_group_size_xy > kernel_info.max_work_group_size)
+ {
+ continue;
+ }
+ for (auto z : possible_z_sizes)
+ {
+ if (work_group_size_xy * z > kernel_info.max_work_group_size)
+ {
+ continue;
+ }
+ if (x <= device_info.max_work_group_size_x && y <= device_info.max_work_group_size_y &&
+ z <= device_info.max_work_group_size_z)
+ {
+ work_groups.push_back({x, y, z});
+ }
+ }
+ }
+ }
+ return work_groups;
+}
+
+std::vector<int3> GenerateWorkGroupSizesXMultipleOf(int multiplier, int3 grid,
+ const KernelInfo &kernel_info,
+ const DeviceInfo &device_info,
+ WorkGroupSizeAlignment z_alignment)
+{
+ std::vector<int3> work_groups;
+ work_groups.reserve(32);
+
+ std::vector<int> possible_z_sizes = GetPossibleSizes(grid.z, z_alignment);
+ std::vector<int> possible_y_sizes = GetPossibleSizes(grid.y, WorkGroupSizeAlignment::PRECISE);
+
+ for (int x = multiplier; x <= kernel_info.max_work_group_size && x < grid.x + multiplier;
+ x += multiplier)
+ {
+ for (auto y : possible_y_sizes)
+ {
+ for (auto z : possible_z_sizes)
+ {
+ if (x <= device_info.max_work_group_size_x && y <= device_info.max_work_group_size_y &&
+ z <= device_info.max_work_group_size_z && x * y * z <= kernel_info.max_work_group_size)
+ {
+ work_groups.push_back({x, y, z});
+ }
+ }
+ }
+ }
+ return work_groups;
+}
+
+void GetWorkGroupsAlignedToGrid(const DeviceInfo &device_info, const KernelInfo &kernel_info,
+ const int3 &grid, std::vector<int3> *work_groups)
+{
+ int3 max_wg_size;
+ max_wg_size.x = device_info.max_work_group_size_x;
+ max_wg_size.y = device_info.max_work_group_size_y;
+ max_wg_size.z = device_info.max_work_group_size_z;
+ GenerateWorkGroupSizesAlignedToGrid(grid, max_wg_size, kernel_info.max_work_group_size,
+ work_groups);
+}
+
+int GetPenalty(int grid_size, int group_size)
+{
+ const int remainder = grid_size % group_size;
+ return remainder == 0 ? 0 : group_size - remainder;
+}
+
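+// The 2D penalty is the number of grid cells added by padding the grid up to
+// work-group multiples: (grid.x + p_x) * (grid.y + p_y) - grid.x * grid.y.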
+int GetPenalty(int2 grid_size, int2 group_size)
+{
+ const int p_x = GetPenalty(grid_size.x, group_size.x);
+ const int p_y = GetPenalty(grid_size.y, group_size.y);
+ return p_x * grid_size.y + p_y * grid_size.x + p_x * p_y;
+}
+
+int GetMaxSizeWithMinPenalty(int size, int max_size)
+{
+ int best_size = 128;
+ int min_penalty = GetPenalty(size, best_size);
+ for (int i = 2; i * 128 <= max_size; ++i)
+ {
+ if (GetPenalty(size, i * 128) == min_penalty)
+ {
+ best_size = i * 128;
+ }
+ }
+ return best_size;
+}
+
+int2 GetMaxSizeWithMinPenalty(int2 size, int max_size)
+{
+ std::vector<int2> base_groups = Get2DWorkgroupsEqualTo128();
+ int min_penalty = std::numeric_limits<int>::max();
+ for (const auto &group : base_groups)
+ {
+ min_penalty = std::min(GetPenalty(size, group), min_penalty);
+ }
+ for (const auto &group : base_groups)
+ {
+ for (int y = 1; y * group.y <= max_size; ++y)
+ {
+ int new_group_y = y * group.y;
+ for (int x = 1; x * group.x <= max_size; ++x)
+ {
+ int new_group_x = x * group.x;
+ if (new_group_x * new_group_y > max_size)
+ {
+ break;
+ }
+ if (GetPenalty(size, int2(new_group_x, new_group_y)) == min_penalty)
+ {
+ return int2(new_group_x, new_group_y);
+ }
+ }
+ }
+ }
+ return int2(0, 0);
+}
+
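+// Prefers the power-of-two dividers 8, 4 and 2 (in that order) and only then
+// falls back to the largest divider not exceeding max_divider.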
+int GetBiggestDividerWithPriority(int number, int max_divider)
+{
+ if (number % 8 == 0 && 8 <= max_divider)
+ {
+ return 8;
+ }
+ if (number % 4 == 0 && 4 <= max_divider)
+ {
+ return 4;
+ }
+ if (number % 2 == 0 && 2 <= max_divider)
+ {
+ return 2;
+ }
+ for (int i = max_divider; i != 0; i--)
+ {
+ if (number % i == 0)
+ {
+ return i;
+ }
+ }
+ return 1;
+}
+
+int GetBiggestDivider(int number, int max_divider)
+{
+ for (int i = max_divider; i != 0; i--)
+ {
+ if (number % i == 0)
+ {
+ return i;
+ }
+ }
+ return 1;
+}
+
+} // namespace
+
+int3 GetWorkGroupXY128ConvLinear(const int3 &grid)
+{
+ int grid_z = GetBiggestDividerWithPriority(grid.z, 4);
+ if (grid.x <= 128)
+ {
+ return int3(128, 1, grid_z);
+ }
+ int grid_x = GetMaxSizeWithMinPenalty(grid.x, 512 / grid_z);
+ return {grid_x, 1, grid_z};
+}
+
+int3 GetWorkGroupXY128Conv(const int3 &grid)
+{
+ int grid_z = GetBiggestDividerWithPriority(grid.z, 4);
+ if (grid.x <= 16 && grid.y <= 8)
+ {
+ return int3(16, 8, grid_z);
+ }
+ int2 grid_xy = GetMaxSizeWithMinPenalty(int2(grid.x, grid.y), 512 / grid_z);
+ return int3(grid_xy.x, grid_xy.y, grid_z);
+}
+
+// int3 GetWorkGroupXY128Simple(const int3& grid) { return int3(16, 8, 1); }
+
+int3 GetWorkGroup(const int3 &grid, int max_size)
+{
+ int wg_z = GetBiggestDividerWithPriority(grid.z, 8);
+ int wg_xy_size = max_size / wg_z;
+ int wg_x = std::min(DivideRoundUp(grid.x, 2), wg_xy_size);
+ int wg_y = std::min(wg_xy_size / wg_x, grid.y);
+ return int3(wg_x, wg_y, wg_z);
+}
+
+int3 GetWorkGroupConv(const int3 &grid, int max_size, int max_z_size)
+{
+ int wg_z = GetBiggestDivider(grid.z, max_z_size);
+ int wg_xy_size = std::min(256, max_size) / wg_z;
+ int wg_x = std::min(grid.x, wg_xy_size);
+ int wg_y = std::min(wg_xy_size / wg_x, grid.y);
+ if (wg_y == grid.y && grid.y % 2 == 0)
+ {
+ wg_y = grid.y / 2;
+ }
+ return int3(wg_x, wg_y, wg_z);
+}
+
+void GetPossibleWorkGroupsXYMultipleOf(int multiplier, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info, const int3 &grid,
+ WorkGroupSizeAlignment z_alignment,
+ std::vector<int3> *work_groups)
+{
+ *work_groups =
+ GenerateWorkGroupSizesXYMultipleOf(multiplier, grid, kernel_info, device_info, z_alignment);
+}
+
+void GetPossibleWorkGroupsXMultipleOf(int multiplier, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info, const int3 &grid,
+ WorkGroupSizeAlignment z_alignment,
+ std::vector<int3> *work_groups)
+{
+ *work_groups =
+ GenerateWorkGroupSizesXMultipleOf(multiplier, grid, kernel_info, device_info, z_alignment);
+}
+
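+// Returns true when none of the 128-thread 2D work groups tiles the
+// width x height grid into as few groups as the linear 128-thread layout.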
+bool XY128RequiresMoreWorkGroupsThenXY128Linear(int width, int height)
+{
+ int planar_work_groups = DivideRoundUp(width * height, 128);
+ auto base_work_groups = Get2DWorkgroupsEqualTo128();
+ bool have_equal_work_groups = false;
+ for (auto &work_group : base_work_groups)
+ {
+ int x_groups = DivideRoundUp(width, work_group.x);
+ int y_groups = DivideRoundUp(height, work_group.y);
+ int xy_groups = x_groups * y_groups;
+ if (xy_groups == planar_work_groups)
+ {
+ have_equal_work_groups = true;
+ break;
+ }
+ }
+ return !have_equal_work_groups;
+}
+
+void GetPossibleWorkGroups(TuningType tuning_type, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info, const int3 &grid,
+ std::vector<int3> *work_groups)
+{
+ switch (tuning_type)
+ {
+ case TuningType::FAST:
+ work_groups->push_back(GetWorkGroup(grid, kernel_info.max_work_group_size));
+ return;
+ case TuningType::EXHAUSTIVE:
+ {
+ GetWorkGroupsAlignedToGrid(device_info, kernel_info, grid, work_groups);
+ return;
+ }
+ default:
+ work_groups->push_back({8, 4, 1});
+ return;
+ }
+}
+
+void GetPossibleWorkGroupsConv(TuningType tuning_type, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info, const int3 &grid,
+ std::vector<int3> *work_groups)
+{
+ switch (tuning_type)
+ {
+ case TuningType::FAST:
+ {
+ int max_z_size = 16;
+ if (device_info.IsAdreno())
+ {
+ max_z_size = device_info.IsAdreno3xx() ? 16 : 64;
+ }
+ max_z_size = std::min(max_z_size, device_info.max_work_group_size_z);
+ work_groups->push_back(GetWorkGroupConv(grid, kernel_info.max_work_group_size, max_z_size));
+ return;
+ }
+ case TuningType::EXHAUSTIVE:
+ {
+ GetWorkGroupsAlignedToGrid(device_info, kernel_info, grid, work_groups);
+ return;
+ }
+ default:
+ work_groups->push_back({8, 4, 1});
+ return;
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_WORK_GROUP_PICKING_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_WORK_GROUP_PICKING_H__
+
+#include <vector>
+
+#include "TuningParameters.h"
+
+#include "open_cl/ClKernel.h"
+#include "open_cl/DeviceInfo.h"
+#include "open_cl/Types.h"
+#include "open_cl/WorkgroupSelection.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+// multiplier must be a power of two
+void GetPossibleWorkGroupsXYMultipleOf(int multiplier, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info, const int3 &grid,
+ WorkGroupSizeAlignment z_alignment,
+ std::vector<int3> *work_groups);
+
+void GetPossibleWorkGroupsXMultipleOf(int multiplier, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info, const int3 &grid,
+ WorkGroupSizeAlignment z_alignment,
+ std::vector<int3> *work_groups);
+
+int3 GetWorkGroupXY128ConvLinear(const int3 &grid);
+
+int3 GetWorkGroupXY128Simple(const int3 &grid);
+int3 GetWorkGroupXY128Conv(const int3 &grid);
+
+bool XY128RequiresMoreWorkGroupsThenXY128Linear(int width, int height);
+
+void GetPossibleWorkGroups(TuningType tuning_type, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info, const int3 &grid,
+ std::vector<int3> *work_groups);
+
+void GetPossibleWorkGroupsConv(TuningType tuning_type, const DeviceInfo &device_info,
+ const KernelInfo &kernel_info, const int3 &grid,
+ std::vector<int3> *work_groups);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_KERNELS_WORK_GROUP_PICKING_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ConvolutionSelector.h"
+
+#include "absl/memory/memory.h"
+#include "open_cl/kernels/ConvBuffer1x1.h"
+#include "open_cl/kernels/ConvConstants.h"
+#include "open_cl/kernels/ConvPowervr.h"
+#include "open_cl/kernels/ConvWeightsConverter.h"
+#include "open_cl/kernels/WorkGroupPicking.h"
+#include "open_cl/TensorType.h"
+#include "open_cl/Util.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+std::unique_ptr<GPUOperation> SelectConvolutionAdreno(const Convolution2DAttributes &attr,
+ const BHWC &dst_shape,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def, ModelHints)
+{
+ if (IsConvConstantsSupported(device_info, op_def, attr))
+ {
+ GPUOperation conv = CreateConvConstants(device_info, op_def, attr);
+ return absl::make_unique<GPUOperation>(std::move(conv));
+ }
+ else
+ {
+ ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape);
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+ }
+}
+
+std::unique_ptr<GPUOperation> SelectConvolutionWinogradAdreno(const Convolution2DAttributes &attr,
+ const BHWC &dst_shape,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def,
+ ModelHints)
+{
+ ConvPowerVR conv = CreateConvPowerVRWino4x4To6x6(device_info, op_def, attr, &dst_shape);
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+}
+
+std::unique_ptr<GPUOperation>
+SelectConvolutionDynamicWeightsAdreno(const Convolution2DAttributes &attr,
+ const BHWC &weights_shape, const BHWC &dst_shape,
+ const DeviceInfo &device_info, const OperationDef &op_def,
+ ModelHints, ConvWeightsDescription *weights_desc)
+{
+ ConvPowerVR conv =
+ CreateConvPowerVRDynamicWeights(device_info, op_def, attr, weights_shape, &dst_shape);
+ *weights_desc = conv.GetConvWeightsDescription();
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+}
+
+std::unique_ptr<GPUOperation> SelectConvolutionNVidia(const Convolution2DAttributes &attr,
+ const BHWC &dst_shape,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def)
+{
+ if (IsConvConstantsSupported(device_info, op_def, attr))
+ {
+ GPUOperation conv = CreateConvConstants(device_info, op_def, attr);
+ return absl::make_unique<GPUOperation>(std::move(conv));
+ }
+ else
+ {
+ ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape);
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+ }
+}
+
+std::unique_ptr<GPUOperation> SelectConvolutionPowerVR(const Convolution2DAttributes &attr,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def)
+{
+ ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr);
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+}
+
+std::unique_ptr<GPUOperation> SelectConvolutionMali(const Convolution2DAttributes &attr,
+ const BHWC &dst_shape,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def)
+{
+ if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER &&
+ IsConvBuffer1x1Supported(op_def, attr))
+ {
+ ConvBuffer1x1 conv = CreateConvBuffer1x1(device_info, op_def, attr, &dst_shape);
+ return absl::make_unique<ConvBuffer1x1>(std::move(conv));
+ }
+ else
+ {
+ ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape);
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+ }
+}
+
+std::unique_ptr<GPUOperation> SelectConvolutionWinogradMali(const Convolution2DAttributes &attr,
+ const BHWC &dst_shape,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def)
+{
+ if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER)
+ {
+ ConvBuffer1x1 conv = CreateConvBuffer1x1Wino4x4To6x6(device_info, op_def, attr, &dst_shape);
+ return absl::make_unique<ConvBuffer1x1>(std::move(conv));
+ }
+ else
+ {
+ ConvPowerVR conv = CreateConvPowerVRWino4x4To6x6(device_info, op_def, attr, &dst_shape);
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+ }
+}
+
+std::unique_ptr<GPUOperation>
+SelectConvolutionDynamicWeightsMali(const Convolution2DAttributes &attr, const BHWC &weights_shape,
+ const BHWC &dst_shape, const DeviceInfo &device_info,
+ const OperationDef &op_def, ModelHints,
+ ConvWeightsDescription *weights_desc)
+{
+ if (op_def.src_tensors[0].storage_type == TensorStorageType::BUFFER &&
+ IsConvBuffer1x1Supported(op_def, weights_shape, attr))
+ {
+ ConvBuffer1x1 conv =
+ CreateConvBuffer1x1DynamicWeights(device_info, op_def, attr, weights_shape, &dst_shape);
+ *weights_desc = conv.GetConvWeightsDescription();
+ return absl::make_unique<ConvBuffer1x1>(std::move(conv));
+ }
+ else
+ {
+ ConvPowerVR conv =
+ CreateConvPowerVRDynamicWeights(device_info, op_def, attr, weights_shape, &dst_shape);
+ *weights_desc = conv.GetConvWeightsDescription();
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+ }
+}
+
+} // namespace
+
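+// Dispatches on the GPU vendor reported by DeviceInfo; unrecognized vendors
+// fall back to the Adreno selection path.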
+std::unique_ptr<GPUOperation> SelectConvolution(const Convolution2DAttributes &attr,
+ const BHWC &dst_shape,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def, ModelHints hints)
+{
+ if (device_info.IsAdreno())
+ {
+ return SelectConvolutionAdreno(attr, dst_shape, device_info, op_def, hints);
+ }
+ else if (device_info.IsPowerVR() || device_info.IsAMD() || device_info.IsIntel())
+ {
+ return SelectConvolutionPowerVR(attr, device_info, op_def);
+ }
+ else if (device_info.IsNvidia())
+ {
+ return SelectConvolutionNVidia(attr, dst_shape, device_info, op_def);
+ }
+ else if (device_info.IsMali())
+ {
+ return SelectConvolutionMali(attr, dst_shape, device_info, op_def);
+ }
+ else
+ {
+ return SelectConvolutionAdreno(attr, dst_shape, device_info, op_def, hints);
+ }
+}
+
+std::unique_ptr<GPUOperation> SelectConvolutionForWinograd(const Convolution2DAttributes &attr,
+ const BHWC &dst_shape,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def,
+ ModelHints hints)
+{
+ if (device_info.IsAdreno())
+ {
+ return SelectConvolutionWinogradAdreno(attr, dst_shape, device_info, op_def, hints);
+ }
+ else if (device_info.IsPowerVR() || device_info.IsAMD() || device_info.IsNvidia() ||
+ device_info.IsIntel())
+ {
+ ConvPowerVR conv = CreateConvPowerVRWino4x4To6x6(device_info, op_def, attr, &dst_shape);
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+ }
+ else if (device_info.IsMali())
+ {
+ return SelectConvolutionWinogradMali(attr, dst_shape, device_info, op_def);
+ }
+ else
+ {
+ return SelectConvolutionWinogradAdreno(attr, dst_shape, device_info, op_def, hints);
+ }
+}
+
+std::unique_ptr<GPUOperation>
+SelectConvolutionWithDynamicWeights(const Convolution2DAttributes &attr, const BHWC &weights_shape,
+ const BHWC &dst_shape, const DeviceInfo &device_info,
+ const OperationDef &op_def, ModelHints hints,
+ ConvWeightsDescription *weights_desc)
+{
+ if (device_info.IsAdreno())
+ {
+ return SelectConvolutionDynamicWeightsAdreno(attr, weights_shape, dst_shape, device_info,
+ op_def, hints, weights_desc);
+ }
+ else if (device_info.IsMali())
+ {
+ return SelectConvolutionDynamicWeightsMali(attr, weights_shape, dst_shape, device_info, op_def,
+ hints, weights_desc);
+ }
+ else
+ {
+ ConvPowerVR conv =
+ CreateConvPowerVRDynamicWeights(device_info, op_def, attr, weights_shape, &dst_shape);
+ *weights_desc = conv.GetConvWeightsDescription();
+ return absl::make_unique<ConvPowerVR>(std::move(conv));
+ }
+}
+
+std::unique_ptr<GPUOperation>
+SelectConverterToConvWeights(const ConvWeightsDescription &weights_desc, const OperationDef &op_def,
+ ModelHints)
+{
+ ConverterToConvWeights converter = ConverterToConvWeights(op_def, weights_desc);
+ return absl::make_unique<ConverterToConvWeights>(std::move(converter));
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_SELECTORS_CONVOLUTION_SELECTOR_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_SELECTORS_CONVOLUTION_SELECTOR_H__
+
+#include <memory>
+
+#include "open_cl/kernels/ConvCommon.h"
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/ModelHints.h"
+#include "open_cl/Operations.h"
+#include "open_cl/Shape.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+std::unique_ptr<GPUOperation> SelectConvolution(const Convolution2DAttributes &attr,
+ const BHWC &dst_shape,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def, ModelHints hints);
+
+std::unique_ptr<GPUOperation> SelectConvolutionForWinograd(const Convolution2DAttributes &attr,
+ const BHWC &dst_shape,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def,
+ ModelHints hints);
+
+std::unique_ptr<GPUOperation>
+SelectConvolutionWithDynamicWeights(const Convolution2DAttributes &attr, const BHWC &weights_shape,
+ const BHWC &dst_shape, const DeviceInfo &device_info,
+ const OperationDef &op_def, ModelHints hints,
+ ConvWeightsDescription *weights_desc);
+
+std::unique_ptr<GPUOperation>
+SelectConverterToConvWeights(const ConvWeightsDescription &weights_desc, const OperationDef &op_def,
+ ModelHints hints);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_SELECTORS_CONVOLUTION_SELECTOR_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DwConvolutionSelector.h"
+
+#include "absl/memory/memory.h"
+#include "open_cl/ClDevice.h"
+#include "open_cl/kernels/DepthwiseConv.h"
+#include "open_cl/kernels/DepthwiseConv3x3.h"
+#include "open_cl/Precision.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace
+{
+
+std::unique_ptr<GPUOperation>
+SelectDWConvolutionAdreno(const DepthwiseConvolution2DAttributes &attr,
+ const DeviceInfo &device_info, const OperationDef &op_def)
+{
+ if (IsDepthwiseConv3x3Supported(attr))
+ {
+ return absl::make_unique<DepthwiseConv3x3>(CreateDepthwiseConv3x3(device_info, op_def, attr));
+ }
+ else
+ {
+ return absl::make_unique<GPUOperation>(CreateDepthwiseConvolution2D(device_info, op_def, attr));
+ }
+}
+
+std::unique_ptr<GPUOperation>
+SelectDWConvolutionPowerVR(const DepthwiseConvolution2DAttributes &attr,
+ const DeviceInfo &device_info, const OperationDef &op_def)
+{
+ if (IsDepthwiseConv3x3Supported(attr))
+ {
+ return absl::make_unique<DepthwiseConv3x3>(CreateDepthwiseConv3x3(device_info, op_def, attr));
+ }
+ else
+ {
+ return absl::make_unique<GPUOperation>(CreateDepthwiseConvolution2D(device_info, op_def, attr));
+ }
+}
+
+std::unique_ptr<GPUOperation> SelectDWConvolutionMali(const DepthwiseConvolution2DAttributes &attr,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def)
+{
+ const auto storage_type = op_def.src_tensors[0].storage_type;
+ bool buffer_type =
+ storage_type == TensorStorageType::BUFFER || storage_type == TensorStorageType::IMAGE_BUFFER;
+ const MaliInfo mali_info = device_info.mali_info;
+ if (IsDepthwiseConv3x3Supported(attr) && !mali_info.IsMidgard() && !buffer_type &&
+ op_def.precision != CalculationsPrecision::F32)
+ {
+ return absl::make_unique<DepthwiseConv3x3>(CreateDepthwiseConv3x3(device_info, op_def, attr));
+ }
+ else
+ {
+ return absl::make_unique<GPUOperation>(CreateDepthwiseConvolution2D(device_info, op_def, attr));
+ }
+}
+} // namespace
+
+std::unique_ptr<GPUOperation> SelectDWConvolution(const DepthwiseConvolution2DAttributes &attr,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def)
+{
+ if (device_info.IsAdreno())
+ {
+ return SelectDWConvolutionAdreno(attr, device_info, op_def);
+ }
+ else if (device_info.IsPowerVR())
+ {
+ return SelectDWConvolutionPowerVR(attr, device_info, op_def);
+ }
+ else if (device_info.IsMali())
+ {
+ return SelectDWConvolutionMali(attr, device_info, op_def);
+ }
+ else
+ {
+ return SelectDWConvolutionAdreno(attr, device_info, op_def);
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_SELECTORS_DW_CONVOLUTION_SELECTOR_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_SELECTORS_DW_CONVOLUTION_SELECTOR_H__
+
+#include <memory>
+
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/Operations.h"
+#include "open_cl/Status.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+std::unique_ptr<GPUOperation> SelectDWConvolution(const DepthwiseConvolution2DAttributes &attr,
+ const DeviceInfo &device_info,
+ const OperationDef &op_def);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_SELECTORS_DW_CONVOLUTION_SELECTOR_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "SimpleSelectors.h"
+
+#include <memory>
+#include <set>
+#include "absl/memory/memory.h"
+
+#include "open_cl/kernels/Add.h"
+#include "open_cl/kernels/DepthwiseConv.h"
+#include "open_cl/kernels/Pooling.h"
+#include "open_cl/kernels/Relu.h"
+#include "open_cl/kernels/Reshape.h"
+#include "open_cl/kernels/Reshapex4.h"
+#include "open_cl/kernels/Softmax.h"
+#include "open_cl/kernels/Softmax1x1.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+void SelectAdd(const OperationDef &op_def, const std::vector<int> &channels, int dst_channels,
+ std::unique_ptr<GPUOperation> *ptr)
+{
+ GPUOperation operation = CreateAdd(op_def, channels, dst_channels);
+ *ptr = std::make_unique<GPUOperation>(std::move(operation));
+}
+
+std::unique_ptr<GPUOperation>
+SelectDWConvolutionDynamicWeights(const DepthwiseConvolution2DAttributes &attr,
+ const DeviceInfo &device_info, const OperationDef &op_def)
+{
+ return absl::make_unique<GPUOperation>(
+ CreateDepthwiseConvolution2DDynamicWeights(device_info, op_def, attr));
+}
+
+std::unique_ptr<GPUOperation> SelectPooling(const Pooling2DAttributes &attr,
+ const OperationDef &op_def)
+{
+ GPUOperation operation = CreatePooling(op_def, attr);
+ return absl::make_unique<GPUOperation>(std::move(operation));
+}
+
+std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes &attr, const OperationDef &op_def)
+{
+ return absl::make_unique<GPUOperation>(CreateReLU(op_def, attr));
+}
+
+void SelectReshape(int src_channels, int dst_channels, const OperationDef &op_def,
+ std::unique_ptr<GPUOperation> *ptr)
+{
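+ // Reshapex4 covers the case where both channel counts are multiples of 4, so
+ // the data can presumably be moved in whole 4-channel (FLT4) slices; the
+ // generic Reshape kernel handles everything else.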
+ if (src_channels % 4 == 0 && dst_channels % 4 == 0)
+ {
+ GPUOperation operation = CreateReshapex4(op_def);
+ *ptr = std::make_unique<GPUOperation>(std::move(operation));
+ }
+ else
+ {
+ GPUOperation operation = CreateReshape(op_def);
+ *ptr = std::make_unique<GPUOperation>(std::move(operation));
+ }
+}
+
+void SelectSoftmax(const BHWC &shape, const OperationDef &op_def,
+ std::unique_ptr<GPUOperation> *ptr)
+{
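+ // With 1x1 spatial dimensions, softmax reduces to a per-channel operation on
+ // a single pixel, so the specialized Softmax1x1 kernel is selected.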
+ if (shape.w == 1 && shape.h == 1)
+ {
+ Softmax1x1 operation = CreateSoftmax1x1(op_def);
+ *ptr = absl::make_unique<Softmax1x1>(std::move(operation));
+ }
+ else
+ {
+ GPUOperation operation = CreateSoftmax(op_def);
+ *ptr = absl::make_unique<GPUOperation>(std::move(operation));
+ }
+}
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPENCL_SELECTORS_SIMPLE_SELECTORS_H__
+#define __ONERT_BACKEND_GPU_CL_OPENCL_SELECTORS_SIMPLE_SELECTORS_H__
+
+#include <memory>
+
+#include "open_cl/ClDevice.h"
+#include "open_cl/kernels/GpuOperation.h"
+#include "open_cl/Operations.h"
+#include "open_cl/Shape.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+
+void SelectAdd(const OperationDef &op_def, const std::vector<int> &channels, int dst_channels,
+ std::unique_ptr<GPUOperation> *ptr);
+
+std::unique_ptr<GPUOperation>
+SelectDWConvolutionDynamicWeights(const DepthwiseConvolution2DAttributes &attr,
+ const DeviceInfo &device_info, const OperationDef &op_def);
+
+std::unique_ptr<GPUOperation> SelectPooling(const Pooling2DAttributes &attr,
+ const OperationDef &op_def);
+
+std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes &attr, const OperationDef &op_def);
+
+void SelectReshape(int src_channels, int dst_channels, const OperationDef &op_def,
+ std::unique_ptr<GPUOperation> *ptr);
+
+void SelectSoftmax(const BHWC &shape, const OperationDef &op_def,
+ std::unique_ptr<GPUOperation> *ptr);
+
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPENCL_SELECTORS_SIMPLE_SELECTORS_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CLTensor.h"
+
+#include "open_cl/Buffer.h"
+#include "open_cl/ClContext.h"
+#include "open_cl/Tensor.h"
+#include "open_cl/TensorType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace operand
+{
+
+CLTensor::CLTensor(size_t rank, ir::Shape shape, std::shared_ptr<Environment> environment)
+ : ICLTensor{rank, shape, environment}, _tensor(std::make_shared<Tensor>())
+{
+}
+
+const Tensor *CLTensor::handle() const { return _tensor.get(); }
+
+Tensor *CLTensor::handle() { return _tensor.get(); }
+
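+// NOTE setBuffer() is currently a no-op; the given host pointer is ignored.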
+void CLTensor::setBuffer(void *host_ptr) { (void)host_ptr; }
+
+} // namespace operand
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPERAND_CL_TENSOR_H__
+#define __ONERT_BACKEND_GPU_CL_OPERAND_CL_TENSOR_H__
+
+#include "ICLTensor.h"
+
+#include "open_cl/Buffer.h"
+#include "open_cl/ClContext.h"
+#include "open_cl/Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace operand
+{
+
+class CLTensor : public ICLTensor
+{
+public:
+ CLTensor() = delete;
+
+public:
+ CLTensor(size_t rank, ir::Shape shape, std::shared_ptr<Environment> environment);
+
+public:
+ const Tensor *handle() const override;
+ Tensor *handle() override;
+
+public:
+ /** Set given buffer as the buffer of the tensor
+ *
+ * @note Ownership of the memory is not transferred to this object.
+ * Thus management (allocate/free) should be done by the client.
+ *
+ * @param[in] host_ptr Storage to be used.
+ */
+ void setBuffer(void *host_ptr);
+
+private:
+ std::shared_ptr<Tensor> _tensor;
+};
+
+} // namespace operand
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPERAND_CL_TENSOR_H__
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ICLTensor.h"
+
+#include "open_cl/Api.h"
+#include "open_cl/Spi.h"
+#include "open_cl/OpenclWrapper.h"
+#include "open_cl/TensorTypeUtil.h"
+#include "open_cl/kernels/Converter.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace operand
+{
+
+void ICLTensor::access(const std::function<void(ITensor &tensor)> &fn)
+{
+ if (total_size() == 0)
+ return;
+
+ fn(*this);
+}
+
+void ICLTensor::enqueueWriteBuffer(const void *ptr, bool)
+{
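+ // Host data (BHWC float32) is staged into a temporary CL memory object of the
+ // tensor's storage type and then converted into the tensor's own memory and
+ // layout, using the two TensorObjectConverters built below.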
+ const float *arr = reinterpret_cast<const float *>(ptr);
+ TensorObject input_obj = MakeReadableCpuMemory(absl::MakeSpan(arr, total_size() / 4));
+
+ TensorObject output_obj;
+
+ if (handle()->GetStorageType() == TensorStorageType::BUFFER)
+ {
+ output_obj = OpenClBuffer{handle()->GetMemoryPtr()};
+ }
+ else if (handle()->GetStorageType() == TensorStorageType::IMAGE_BUFFER)
+ {
+ output_obj = OpenClBuffer{handle()->GetMemoryPtrForWriting()};
+ }
+ else
+ {
+ output_obj = OpenClTexture{handle()->GetMemoryPtr()};
+ }
+
+ TensorObjectDef input_def;
+ input_def.dimensions.b = handle()->Batch();
+ input_def.dimensions.h = handle()->Height();
+ input_def.dimensions.w = handle()->Width();
+ input_def.dimensions.c = handle()->Channels();
+ input_def.object_def.data_layout = DataLayout::BHWC;
+ input_def.object_def.data_type = DataType::FLOAT32;
+ input_def.object_def.object_type = ObjectType::CPU_MEMORY;
+ input_def.object_def.user_provided = true;
+
+ TensorObjectDef tmp_def;
+ tmp_def.dimensions.b = handle()->Batch();
+ tmp_def.dimensions.h = handle()->Height();
+ tmp_def.dimensions.w = handle()->Width();
+ tmp_def.dimensions.c = handle()->Channels();
+ tmp_def.object_def.data_layout = DataLayout::BHWC;
+ tmp_def.object_def.data_type = DataType::FLOAT32;
+ tmp_def.object_def.object_type = ToObjectType(handle()->GetStorageType());
+ tmp_def.object_def.user_provided = true;
+
+ auto dims = tmp_def.dimensions;
+ const BHWC shape(dims.b, dims.h, dims.w, dims.c);
+ const TensorDescriptor desc{
+ tmp_def.object_def.data_type,
+ ToTensorStorageType(tmp_def.object_def.object_type, tmp_def.object_def.data_layout),
+ Layout::BHWC};
+ if (!AllocateTensorMemory(_environment->context(), shape, desc, &_cl_memory).ok())
+ {
+ throw std::runtime_error("AllocateTensorMemory error.");
+ }
+ TensorObject tmp_obj;
+ if (tmp_def.object_def.object_type == ObjectType::OPENCL_TEXTURE)
+ {
+ tmp_obj = OpenClTexture{_cl_memory.memory()};
+ }
+ else
+ {
+ tmp_obj = OpenClBuffer{_cl_memory.memory()};
+ }
+
+ TensorObjectDef output_def = input_def;
+ output_def.dimensions.b = handle()->Batch();
+ output_def.dimensions.h = handle()->Height();
+ output_def.dimensions.w = handle()->Width();
+ output_def.dimensions.c = handle()->Channels();
+ output_def.object_def.data_layout = ToDataLayout(handle()->GetStorageType());
+ output_def.object_def.data_type = handle()->GetDataType();
+ output_def.object_def.object_type = ToObjectType(handle()->GetStorageType());
+
+ _converter_builder = NewConverterBuilder(_environment.get());
+ if (!_converter_builder->MakeConverter(input_def, tmp_def, &_converter_cpu).ok())
+ {
+ throw std::runtime_error("MakeConverter<_converter_cpu> error.");
+ }
+ if (!_converter_builder->MakeConverter(tmp_def, output_def, &_converter_bhwc).ok())
+ {
+ throw std::runtime_error("MakeConverter<_converter_bhwc> error.");
+ }
+
+ if (!_converter_cpu->Convert(input_obj, tmp_obj).ok())
+ {
+ throw std::runtime_error("[w] _converter_cpu Convert error.");
+ }
+ if (!_converter_bhwc->Convert(tmp_obj, output_obj).ok())
+ {
+ throw std::runtime_error("[w] _converter_bhwc Convert error.");
+ }
+}
+
+void ICLTensor::enqueueReadBuffer(void *ptr, bool)
+{
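+ // Mirror of enqueueWriteBuffer: the tensor's memory object is converted into
+ // a temporary CL memory object and then into the caller's BHWC float32 buffer.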
+ float *arr = reinterpret_cast<float *>(ptr);
+ TensorObject output_obj = MakeCpuMemory(absl::MakeSpan(arr, total_size() / 4));
+
+ TensorObject input_obj;
+
+ if (handle()->GetStorageType() == TensorStorageType::BUFFER)
+ {
+ input_obj = OpenClBuffer{handle()->GetMemoryPtr()};
+ }
+ else if (handle()->GetStorageType() == TensorStorageType::IMAGE_BUFFER)
+ {
+ input_obj = OpenClBuffer{handle()->GetMemoryPtrForWriting()};
+ }
+ else
+ {
+ input_obj = OpenClTexture{handle()->GetMemoryPtr()};
+ }
+
+ TensorObjectDef input_def;
+ input_def.dimensions.b = handle()->Batch();
+ input_def.dimensions.h = handle()->Height();
+ input_def.dimensions.w = handle()->Width();
+ input_def.dimensions.c = handle()->Channels();
+ input_def.object_def.data_layout = ToDataLayout(handle()->GetStorageType());
+ input_def.object_def.data_type = handle()->GetDataType();
+ input_def.object_def.object_type = ToObjectType(handle()->GetStorageType());
+ input_def.object_def.user_provided = false;
+
+ TensorObjectDef tmp_def;
+ tmp_def.dimensions.b = handle()->Batch();
+ tmp_def.dimensions.h = handle()->Height();
+ tmp_def.dimensions.w = handle()->Width();
+ tmp_def.dimensions.c = handle()->Channels();
+ tmp_def.object_def.data_layout = DataLayout::BHWC;
+ tmp_def.object_def.data_type = DataType::FLOAT32;
+ tmp_def.object_def.object_type = ToObjectType(handle()->GetStorageType());
+ tmp_def.object_def.user_provided = true;
+
+ auto dims = tmp_def.dimensions;
+ const BHWC shape(dims.b, dims.h, dims.w, dims.c);
+ const TensorDescriptor desc{
+ tmp_def.object_def.data_type,
+ ToTensorStorageType(tmp_def.object_def.object_type, tmp_def.object_def.data_layout),
+ Layout::BHWC};
+ if (!AllocateTensorMemory(_environment->context(), shape, desc, &_cl_memory).ok())
+ {
+ throw std::runtime_error("AllocateTensorMemory error.");
+ }
+ TensorObject tmp_obj;
+ if (tmp_def.object_def.object_type == ObjectType::OPENCL_TEXTURE)
+ {
+ tmp_obj = OpenClTexture{_cl_memory.memory()};
+ }
+ else
+ {
+ tmp_obj = OpenClBuffer{_cl_memory.memory()};
+ }
+ TensorObjectDef output_def = input_def;
+ output_def.dimensions.b = handle()->Batch();
+ output_def.dimensions.h = handle()->Height();
+ output_def.dimensions.w = handle()->Width();
+ output_def.dimensions.c = handle()->Channels();
+ output_def.object_def.data_layout = DataLayout::BHWC;
+ output_def.object_def.data_type = DataType::FLOAT32;
+ output_def.object_def.object_type = ObjectType::CPU_MEMORY;
+ output_def.object_def.user_provided = true;
+
+ _converter_builder = NewConverterBuilder(_environment.get());
+ if (!_converter_builder->MakeConverter(input_def, tmp_def, &_converter_bhwc).ok())
+ {
+ throw std::runtime_error("MakeConverter<_converter_bhwc> error.");
+ }
+ if (!_converter_builder->MakeConverter(tmp_def, output_def, &_converter_cpu).ok())
+ {
+ throw std::runtime_error("MakeConverter<_converter_cpu> error.");
+ }
+
+ if (!_converter_bhwc->Convert(input_obj, tmp_obj).ok())
+ {
+ throw std::runtime_error("[r] _converter_bhwc Convert error.");
+ }
+ if (!_converter_cpu->Convert(tmp_obj, output_obj).ok())
+ {
+ throw std::runtime_error("[r] _converter_cpu Convert error.");
+ }
+}
+
+} // namespace operand
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_GPU_CL_OPERAND_I_CL_TENSOR_H__
+#define __ONERT_BACKEND_GPU_CL_OPERAND_I_CL_TENSOR_H__
+
+#include <backend/ITensor.h>
+
+#include "open_cl/Api.h"
+#include "open_cl/Spi.h"
+#include "open_cl/ClCommandQueue.h"
+#include "open_cl/kernels/Converter.h"
+#include "open_cl/Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace gpu_cl
+{
+namespace operand
+{
+
+class ICLTensor : public ITensor
+{
+public:
+ ICLTensor() = default;
+ ICLTensor(const ICLTensor &) = delete;
+ ICLTensor &operator=(const ICLTensor &) = delete;
+ ICLTensor(ICLTensor &&) = default;
+ ICLTensor &operator=(ICLTensor &&) = default;
+
+ ICLTensor(size_t rank, ir::Shape shape, std::shared_ptr<Environment> environment)
+ : _rank{rank}, _shape{shape}, _environment(environment)
+ {
+ }
+
+public:
+ uint8_t *buffer() const final { return reinterpret_cast<uint8_t *>(handle()->GetMemoryPtr()); }
+ size_t total_size() const final { return _shape.num_elements() * sizeof(float); }
+ size_t calcOffset(const ir::Coordinates &coords) const final
+ {
+ // NYI
+ (void)coords;
+ return 0;
+ }
+ ir::Layout layout() const final { return ir::Layout::NHWC; }
+ ir::DataType data_type() const final { return ir::DataType::FLOAT32; }
+ float data_scale() const override
+ {
+ throw std::runtime_error("ICLTensor::data_scale() is not supported.");
+ }
+ int32_t data_zero_point() const override
+ {
+ throw std::runtime_error("ICLTensor::data_zero_point() is not supported.");
+ }
+ const std::vector<float> &data_scales() const override
+ {
+ throw std::runtime_error("ICLTensor::data_scales() is not supported.");
+ }
+ const std::vector<int32_t> &data_zero_points() const override
+ {
+ throw std::runtime_error("ICLTensor::data_zero_points() is not supported.");
+ }
+ bool is_dynamic() const override { return false; }
+ ir::Shape getShape() const override { return _shape; }
+ bool has_padding() const override { return false; }
+ void access(const std::function<void(ITensor &tensor)> &fn) final;
+ bool needMemoryMap() const final { return true; }
+ void enqueueWriteBuffer(const void *ptr, bool blocking = true) final;
+ void enqueueReadBuffer(void *ptr, bool blocking = true) final;
+
+public:
+ virtual const Tensor *handle() const = 0;
+ virtual Tensor *handle() = 0;
+
+protected:
+ size_t _rank; // Actual rank (reflects extended rank)
+ ir::Shape _shape;
+ std::shared_ptr<Environment> _environment;
+ std::unique_ptr<TensorObjectConverterBuilder> _converter_builder;
+ CLMemory _cl_memory;
+ std::unique_ptr<TensorObjectConverter> _converter_cpu;
+ std::unique_ptr<TensorObjectConverter> _converter_bhwc;
+};
+
+} // namespace operand
+} // namespace gpu_cl
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_GPU_CL_OPERAND_I_CL_TENSOR_H__
std::unordered_map<ir::OperationIndex, std::string> index_to_backend;
};
+struct PartialGraphOptions
+{
+ std::unordered_map<ir::OperationIndex, ir::SubgraphIndex> index_to_graph;
+};
+
struct CompilerOptions
{
// GENERAL OPTIONS
bool he_profiling_mode; //< Whether HEScheduler profiling mode ON/OFF
bool disable_compile; //< Run with Interpreter if true, try compilation otherwise
bool fp16_enable; //< Whether fp16 mode ON/OFF
+ PartialGraphOptions partial_graph_options;
util::TracingCtx *tracing_ctx; //< Profiling information
};
*/
std::shared_ptr<exec::ExecutorMap> compile(void);
+ /**
+ * @brief Do compilation with the options
+ *
+ * @return std::vector<std::shared_ptr<exec::ExecutorMap>> Executors for each pipeline stage
+ * produced by the compilation
+ */
+ std::vector<std::shared_ptr<exec::ExecutorMap>> compile(const char *package_file_path,
+ const char *map_file_path);
+
State state(void) const { return _state; }
CompilerOptions &options() { return _options; }
*/
void enableToFp16();
+ /**
+ * @brief Set backends from string-encoded mappings from operation index to backend type (cpu,
+ * acl_cl)
+ */
+ void set_backend_from_str(const char *backend_settings);
+
+ /**
+ * @brief Build the partial graphs to be compiled along with the original graph
+ */
+ bool buildPartialGraph(uint32_t num_graphs);
+
private:
void checkProfilerConditions();
std::shared_ptr<ir::Graph> &primary_subgraph() { return _subgraphs->at(ir::SubgraphIndex{0}); }
{
public:
LoweredGraph(const ir::Graph &graph, const compiler::CompilerOptions &options);
+ LoweredGraph(const ir::Graph &parent_graph, const ir::Graph &graph,
+ const compiler::CompilerOptions &options);
ir::Graph &graph() { return _graph; }
const ir::Graph &graph() const { return _graph; }
+ ir::Graph &parent_graph() { return _parent_graph; }
+ const ir::Graph &parent_graph() const { return _parent_graph; }
const compiler::GraphLowerInfo &lower_info() const { return _lower_info_map; }
compiler::GraphLowerInfo &lower_info() { return _lower_info_map; }
std::shared_ptr<ir::OperationIndexMap<int64_t>> indexed_ranks() { return _indexed_ranks; }
private:
ir::Graph _graph;
+ ir::Graph _parent_graph;
std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks;
compiler::GraphLowerInfo _lower_info_map;
ir::OperationIndexMap<bool> _has_dynamic_tensor_map;
#include "IODescription.h"
#include <thread>
+#include <deque>
+#include <semaphore.h>
namespace onert
{
*/
const ir::Graph &primary_subgraph() const { return primary_executor()->graph(); }
+ const ir::Graph &primary_parentgraph() const { return primary_executor()->parent_graph(); }
/**
* @brief Change input shape
* @param[in] index Input index
*/
void setInput(const ir::IOIndex &index, const void *buffer, size_t length,
ir::Layout layout = ir::Layout::NHWC);
+
/**
* @brief Set input data's information, especially to specify unknown dimensions on model
* build time.
ir::Shape getInputShape(ir::IOIndex ind) const;
ir::Shape getOutputShape(ir::IOIndex ind) const;
+ //
+ // Experimental API
+ //
+
+ // accessor
+ std::vector<
+ std::tuple<std::shared_ptr<onert::exec::Execution>, onert::ir::IOIndex, onert::ir::IOIndex>>
+ getNextExes()
+ {
+ return next_exes;
+ }
+ std::deque<std::pair<IODescription *, uint32_t>> *getAsyncIoDescs() { return &_async_io_descs; }
+ std::deque<std::vector<void *>> *getAsyncResults() { return &_async_results; }
+
+ /**
+ * @brief Push IO information between related executions into next_exes
+ * @param[in] next Pointer to the next execution
+ * @param[in] o_index Output index of current execution (it will be the input of next execution)
+ * @param[in] i_index Input index of next execution
+ */
+ void pushNextExe(std::shared_ptr<onert::exec::Execution> next, onert::ir::IOIndex o_index,
+ onert::ir::IOIndex i_index)
+ {
+ next_exes.push_back({next, o_index, i_index});
+ }
+
+ /**
+ * @brief Create a new IODescription instance for new inputs and outputs
+ * @param[in] index instance count number
+ */
+ void createNewAsyncDesc(uint32_t count = 0);
+
+ /**
+ * @brief Set async input data's information
+ * @param[in] index Input index
+ * @param[in] buffer Input data's buffer pointer
+ * @param[in] length Input data's length
+ * @param[in] layout Input data's data format
+ */
+ void executeAsyncInput(const ir::IOIndex &index, const void *buffer, size_t length,
+ ir::Layout layout = ir::Layout::NHWC);
+
+ /**
+ * @brief Set async output data's information
+ * @param[in] index Output index
+ * @param[in] buffer Output data's buffer pointer
+ * @param[in] length Output data's length
+ * @param[in] layout Output data's data format
+ */
+ void executeAsyncOutput(const ir::IOIndex &index, void *buffer, size_t length,
+ ir::Layout layout = ir::Layout::NHWC);
+
+ /**
+ * @brief Async execution
+ * @note It should be called after setting the input and output buffers
+ */
+ void AsyncExecute();
+
+ /**
+ * @brief Set finish
+ */
+ void setFinish();
+
+ /**
+ * @brief Check if input queue is empty
+ * @return @c true if queue is empty, otherwise @c false
+ */
+ bool isEmptyQueue();
+
+ /**
+ * @brief Wait semaphore to prevent race condition
+ */
+ void asyncIoDescSemWait();
+
+ /**
+ * @brief Post semaphore to prevent race condition
+ */
+ void asyncIoDescSemPost();
+
+ /**
+ * @brief Run inference
+ * @note This function is provided to the consumer thread for pipelining
+ */
+ void runInference();
+
+ /**
+ * @brief Check if stop_wait is true
+ * @return @c true if stop_wait is true, otherwise @c false
+ */
+ bool stopWait(void) const;
+
+ /**
+ * @brief Set stop_wait to terminate the consumer thread
+ */
+ void sholudStop();
+
private:
const std::unique_ptr<IExecutor> &primary_executor() const
{
private:
const std::shared_ptr<ExecutorMap> _executors;
IODescription _io_desc;
+ std::deque<std::pair<IODescription *, uint32_t>> _async_io_descs;
+ sem_t _async_io_descs_sem;
+ std::deque<std::vector<void *>> _async_results;
+ std::vector<
+ std::tuple<std::shared_ptr<onert::exec::Execution>, onert::ir::IOIndex, onert::ir::IOIndex>>
+ next_exes;
std::unique_ptr<std::thread> _exec_thread;
bool finished{false};
+ bool stop_wait{false};
};
} // namespace exec
virtual const ir::Graph &graph() = 0;
/**
+ * @brief Returns parent graph object
+ *
+ * @return Graph object
+ */
+ virtual const ir::Graph &parent_graph() = 0;
+
+ /**
* @brief Set an ordering on operations
* @param[in] ranks The table encoding the ordering
*/
#include <vector>
#include <unordered_map>
+#include <semaphore.h>
#include "ir/OperandInfo.h"
#include "ir/Index.h"
void removeOperand(const OperandIndex &ind) { _operands.remove(ind); }
void setLayout(Layout layout) { _layout = layout; }
void setSubgraphs(const std::shared_ptr<Subgraphs> &subgs) { _subgraphs = subgs; }
+ void setPartialgraphs(const std::shared_ptr<Subgraphs> &partialgraphs)
+ {
+ _partialgraphs = partialgraphs;
+ }
+ void
+ setTensorName(std::shared_ptr<std::unordered_map<ir::OperandIndex, std::string>> &tensor_names)
+ {
+ _tensor_names = tensor_names;
+ }
private:
bool checkOperandsForOperation(const Operation &operation);
const std::shared_ptr<Subgraphs> &subgraphs() const { return _subgraphs; }
std::shared_ptr<Subgraphs> &subgraphs() { return _subgraphs; }
Layout layout() const { return _layout; }
+ std::shared_ptr<Subgraphs> &partialgraphs() { return _partialgraphs; }
+ std::shared_ptr<std::unordered_map<ir::OperandIndex, std::string>> &tensor_names()
+ {
+ return _tensor_names;
+ }
+ std::unordered_map<std::string, IOIndex>::iterator _name_to_input_begin()
+ {
+ return _name_to_input.begin();
+ }
+ std::unordered_map<std::string, IOIndex>::iterator _name_to_input_end()
+ {
+ return _name_to_input.end();
+ }
+ std::unordered_map<std::string, IOIndex>::iterator _name_to_output_begin()
+ {
+ return _name_to_output.begin();
+ }
+ std::unordered_map<std::string, IOIndex>::iterator _name_to_output_end()
+ {
+ return _name_to_output.end();
+ }
+ void input_sort() { _inputs.sort(); }
+ void output_sort() { _outputs.sort(); }
// Topological sort
public:
std::shared_ptr<Subgraphs> _subgraphs;
// TFLite and circle's default layout is NHWC;
Layout _layout{Layout::NHWC};
+
+ // Partial Graphs
+ std::shared_ptr<ir::Subgraphs> _partialgraphs;
+ std::shared_ptr<std::unordered_map<ir::OperandIndex, std::string>> _tensor_names;
};
} // namespace ir
#include <initializer_list>
#include <vector>
+#include <algorithm>
#include "ir/Index.h"
void append(const OperandIndex &index) { _vec.emplace_back(index); }
void append(const OperandIndexSequence &l) { _vec.insert(_vec.end(), l.begin(), l.end()); }
+ void sort()
+ {
+ std::sort(_vec.begin(), _vec.end(),
+ [](const auto &lhs, const auto &rhs) { return lhs.value() < rhs.value(); });
+ }
+
public:
uint32_t size() const { return static_cast<uint32_t>(_vec.size()); }
const OperandIndex &at(IOIndex set_index) const { return _vec.at(set_index.value()); }
enum class ElementwiseBinaryType
{
+ FLOOR_DIV,
LOGICAL_AND,
LOGICAL_OR,
MAX,
// Name | Type | Default
CONFIG(GRAPH_DOT_DUMP , int , "0")
-CONFIG(BACKENDS , std::string , "cpu;acl_cl;acl_neon;ruy;xnnpack;bcq") // FIXME Remove bcq
+CONFIG(BACKENDS , std::string , "cpu;acl_cl;acl_neon;ruy;xnnpack;gpu_cl;bcq") // FIXME Remove bcq
CONFIG(OP_BACKEND_ALLOPS , std::string , "")
CONFIG(OP_BACKEND_MAP , std::string , "")
CONFIG(DISABLE_COMPILE , bool , "0")
#include "util/ConfigSource.h"
#include "util/logging.h"
#include "ir/OperationDumper.h"
+#include "ir/OperationCloner.h"
#include "misc/string_helpers.h"
+#include "json/json.h"
namespace
{
void Compiler::enableToFp16() { _options.fp16_enable = true; }
+void Compiler::set_backend_from_str(const char *backend_settings)
+{
+ // Parse "index=backend" pairs separated by ';' and record them in index_to_backend
+ auto &ms_options = _options.manual_scheduler_options;
+ auto key_val_list = nnfw::misc::split(backend_settings, ';');
+ for (const auto &key_val_str : key_val_list)
+ {
+ if (key_val_str.empty())
+ {
+ continue;
+ }
+
+ auto key_val = nnfw::misc::split(key_val_str, '=');
+ const auto &key_str = key_val.at(0);
+ const auto &val = key_val.at(1);
+ auto key = static_cast<uint32_t>(std::stoi(key_str));
+ ms_options.index_to_backend.emplace(ir::OperationIndex{key}, val);
+ }
+}
+
void Compiler::checkProfilerConditions()
{
if (!_options.he_scheduler)
throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
}
+bool Compiler::buildPartialGraph(uint32_t num_graphs)
+{
+ if (_subgraphs->count() > 1)
+ return false;
+
+ auto partialgraphs = std::make_shared<ir::Subgraphs>();
+
+ for (uint32_t idx = 0; idx < num_graphs; idx++)
+ {
+ auto partialgraph = std::make_unique<ir::Graph>();
+ partialgraphs->push(ir::SubgraphIndex{idx}, std::move(partialgraph));
+ }
+ _subgraphs->primary()->setPartialgraphs(partialgraphs);
+
+ auto partial_graph = primary_subgraph()->partialgraphs();
+
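+ // Step 1: Copy every operand of the original graph into each partition that uses it,
+ //         as indicated by the partition map.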
+ primary_subgraph()->operands().iterate(
+ [&](const ir::OperandIndex &operand_index, const ir::Operand &operand) {
+ auto use_operations = operand.getUses();
+
+ for (auto use_operation : use_operations)
+ {
+ auto graph_index = _options.partial_graph_options.index_to_graph.find(use_operation);
+ if (graph_index == _options.partial_graph_options.index_to_graph.end())
+ {
+ throw std::runtime_error("Invalid Partition Map");
+ }
+ auto partition = partial_graph->at(graph_index->second);
+
+ if (partition->operands().exist(operand_index))
+ {
+ continue;
+ }
+
+ auto new_operand = std::make_unique<ir::Operand>(operand);
+ new_operand->clearDefUse();
+ auto new_operand_ind = partition->addOperand(operand_index, std::move(new_operand));
+ UNUSED_RELEASE(new_operand_ind);
+ assert(new_operand_ind == operand_index);
+ }
+ });
+
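+ // Step 2: Copy each operation (and any of its I/O operands not yet present) into the
+ //         partition given by the partition map.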
+ primary_subgraph()->operations().iterate(
+ [&](const ir::OperationIndex &operation_index, const ir::Operation &operation) {
+ auto graph_index = _options.partial_graph_options.index_to_graph.find(operation_index);
+ if (graph_index == _options.partial_graph_options.index_to_graph.end())
+ {
+ throw std::runtime_error("Invalid Partition Map");
+ }
+ auto partition = partial_graph->at(graph_index->second);
+
+ auto operand_io = (operation.getInputs() + operation.getOutputs()) | ir::Remove::DUPLICATED |
+ ir::Remove::UNDEFINED;
+ for (auto operand_index : operand_io)
+ {
+ if (partition->operands().exist(operand_index))
+ continue;
+
+ const auto &operand = primary_subgraph()->operands().at(operand_index);
+
+ auto new_operand = std::make_unique<ir::Operand>(operand);
+ new_operand->clearDefUse();
+
+ auto new_operand_index = partition->addOperand(operand_index, std::move(new_operand));
+ UNUSED_RELEASE(new_operand_index);
+ assert(new_operand_index == operand_index);
+ }
+
+ auto new_operation_index = partition->addOperation(operation_index, clone(operation));
+ UNUSED_RELEASE(new_operation_index);
+ assert(new_operation_index == operation_index);
+ });
+
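+ // Step 3: Derive each partition's inputs/outputs: original graph I/O, operands without a
+ //         defining operation, and operands crossing partition boundaries become partition I/O.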
+ for (uint32_t idx = 0; idx < partial_graph->count(); idx++)
+ {
+ auto partition = partial_graph->at(ir::SubgraphIndex{idx});
+
+ partition->operands().iterate([&](const ir::OperandIndex &operand_index,
+ const ir::Operand &operand) {
+ if (primary_subgraph()->getInputs().contains(operand_index) ||
+ (!operand.getDef().valid() && !operand.isConstant()))
+ {
+ partition->addInput(operand_index, primary_subgraph()->tensor_names()->at(operand_index));
+ }
+ if (primary_subgraph()->getOutputs().contains(operand_index) || operand.getUses().size() == 0)
+ {
+ partition->addOutput(operand_index, primary_subgraph()->tensor_names()->at(operand_index));
+ }
+
+ if (primary_subgraph()->operands().at(operand_index).getUses().size() > 1 &&
+ !primary_subgraph()->operands().at(operand_index).isConstant() &&
+ !partition->getInputs().contains(operand_index))
+ {
+ auto use_operations = primary_subgraph()->operands().at(operand_index).getUses();
+ auto iter = use_operations.begin();
+ ir::SubgraphIndex graph_index =
+ _options.partial_graph_options.index_to_graph.find(*iter++)->second;
+ while (iter != use_operations.end())
+ {
+ if (graph_index != _options.partial_graph_options.index_to_graph.find(*iter)->second &&
+ !partition->getOutputs().contains(operand_index))
+ {
+ partition->addOutput(operand_index,
+ primary_subgraph()->tensor_names()->at(operand_index));
+ }
+ iter++;
+ }
+ }
+ });
+
+ partition->verify();
+
+ bool same = true;
+ if (partition->getInputs().size() == primary_subgraph()->getInputs().size())
+ {
+ for (auto iter = partition->getInputs().begin(); iter != partition->getInputs().end(); ++iter)
+ {
+ if (!primary_subgraph()->getInputs().contains(*iter))
+ {
+ same = false;
+ break;
+ }
+ }
+ if (same == true)
+ {
+ partition->getInputs() = primary_subgraph()->getInputs();
+ }
+ else
+ {
+ partition->input_sort();
+ }
+ }
+
+ same = true;
+ if (partition->getOutputs().size() == primary_subgraph()->getOutputs().size())
+ {
+ for (auto iter = partition->getOutputs().begin(); iter != partition->getOutputs().end();
+ ++iter)
+ {
+ if (!primary_subgraph()->getOutputs().contains(*iter))
+ {
+ same = false;
+ break;
+ }
+ }
+ if (same == true)
+ {
+ partition->getOutputs() = primary_subgraph()->getOutputs();
+ }
+ else
+ {
+ partition->output_sort();
+ }
+ }
+ }
+ return true;
+}
+
std::shared_ptr<exec::ExecutorMap> Compiler::compile(void)
{
// Set control flow backend for control flow operators
return executors;
}
+std::vector<std::shared_ptr<exec::ExecutorMap>> Compiler::compile(const char *package_file_path,
+ const char *map_file_path)
+{
+ std::vector<std::shared_ptr<exec::ExecutorMap>> executors;
+ auto executor_map = std::make_shared<exec::ExecutorMap>();
+
+ std::string package_path(package_file_path);
+ std::string partition_map_file;
+
+ if (map_file_path)
+ {
+ partition_map_file = map_file_path;
+ }
+ else
+ {
+ partition_map_file = package_path + "/partition_map.json";
+ }
+
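+ // Expected partition map format (illustrative values):
+ //   { "num_partitions": 2, "partition_map": [0, 0, 1, 1, ...] }
+ // where partition_map[i] is the partition index assigned to operation i.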
+ std::ifstream pmfs(partition_map_file);
+ Json::Value root;
+ pmfs >> root;
+ const Json::Value &map = root["partition_map"];
+ const Json::Value &np = root["num_partitions"];
+
+ uint32_t num_graphs = 1;
+
+ if (pmfs.is_open())
+ {
+ num_graphs = np.asUInt();
+ for (uint32_t i = 0; i < (uint32_t)map.size(); ++i)
+ {
+ _options.partial_graph_options.index_to_graph[ir::OperationIndex{i}] =
+ ir::SubgraphIndex{map[i].asUInt()};
+ }
+ }
+ else
+ {
+ throw std::runtime_error("There is no partition map file");
+ }
+
+ if (!buildPartialGraph(num_graphs))
+ {
+ throw std::runtime_error("It doesn't support in case there are subgraphs");
+ }
+
+ // Set control flow backend for control flow operators
+ {
+ auto &builtin_id = backend::builtin::Config::ID;
+ _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::If] = builtin_id;
+ _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::While] = builtin_id;
+ _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::Permute] = builtin_id;
+ }
+
+ // FIXME This is a workaround for bcq operations, should remove it
+ {
+ _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQFullyConnected] = "bcq";
+ _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQGather] = "bcq";
+ }
+
+ // Tracing is not supported for partial graphs
+ {
+ _options.tracing_ctx = nullptr;
+ }
+
+ {
+ VERBOSE(Compiler) << std::boolalpha << "==== Compiler Options ====" << std::endl;
+ VERBOSE(Compiler) << "backend_list : "
+ << nnfw::misc::join(_options.backend_list.begin(),
+ _options.backend_list.end(), "/")
+ << std::endl;
+ VERBOSE(Compiler) << "trace_filepath : " << _options.trace_filepath << std::endl;
+ VERBOSE(Compiler) << "graph_dump_level : " << _options.graph_dump_level << std::endl;
+ VERBOSE(Compiler) << "executor : " << _options.executor << std::endl;
+ VERBOSE(Compiler) << "manual backend_for_all : "
+ << _options.manual_scheduler_options.backend_for_all << std::endl;
+ VERBOSE(Compiler) << "manual_scheduler_options : "
+ << getOpBackends(_options.manual_scheduler_options.opcode_to_backend)
+ << std::endl;
+ VERBOSE(Compiler) << "he_scheduler : " << _options.he_scheduler << std::endl;
+ VERBOSE(Compiler) << "he_profiling_mode : " << _options.he_profiling_mode << std::endl;
+ VERBOSE(Compiler) << "disable_compile : " << _options.disable_compile << std::endl;
+ VERBOSE(Compiler) << "fp16_enable : " << _options.fp16_enable << std::endl
+ << std::noboolalpha;
+ }
+
+ _subgraphs->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
+ // Mandatory passes
+ auto part = subg.partialgraphs();
+ part->iterate([&](const ir::SubgraphIndex &, ir::Graph &partialgraph) {
+ pass::PassRunner{}
+ .append(std::make_unique<pass::ConstantOutputPass>(partialgraph))
+ .append(std::make_unique<pass::OddOutputPass>(partialgraph))
+ .run();
+
+ // Optimizations
+ pass::PassRunner{}
+ .append(std::make_unique<pass::UnusedOperandEliminationPass>(partialgraph))
+ .run();
+ });
+ });
+
+ /***************************************************
+ * Prepare compilation phase
+ ***************************************************/
+
+ // Compilable check
+ // TODO: Support hybrid execution -
+ // execution between interpreter and compiled executor (including control flow)
+ if (_options.disable_compile)
+ {
+ _subgraphs->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) {
+ executor_map->emplace(index, std::make_unique<interp::InterpExecutor>(subg));
+ executors.push_back(executor_map);
+ });
+ _state = State::COMPILED;
+ return executors;
+ }
+
+ // Mode check
+ if (_options.he_profiling_mode)
+ checkProfilerConditions();
+
+ /***************************************************
+ * Backend independent analysis & optimization phase
+ ***************************************************/
+ auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options.graph_dump_level);
+
+ // Lower: Assign backend
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>>
+ lowered_partialgraphs;
+ _subgraphs->iterate([&](const ir::SubgraphIndex &, ir::Graph &subg) {
+ auto part = subg.partialgraphs();
+ part->iterate([&](const ir::SubgraphIndex &pindex, ir::Graph &partialgraph) {
+ onert::dumper::dot::DotDumper dot_dumper_part(partialgraph, dump_level);
+ dot_dumper_part.dump(nnfw::misc::str("before_lower_subg_partialgraph-", pindex.value()));
+
+ // Lower: Assign backend
+ lowered_partialgraphs[pindex] =
+ std::make_unique<compiler::LoweredGraph>(subg, partialgraph, _options);
+ partialgraph.setSubgraphs(nullptr);
+ });
+ });
+
+ for (auto &pair : lowered_partialgraphs)
+ {
+
+ const auto &partialgraph_index = pair.first;
+ auto &lowered_partialgraph = pair.second;
+ onert::dumper::dot::DotDumper dot_dumper_lowered_part(lowered_partialgraph.get(), dump_level);
+ dot_dumper_lowered_part.dump("after_lower_subg_partialgraph-" +
+ std::to_string(partialgraph_index.value()));
+ }
+
+ // Partial Graph shape inference
+ for (auto &pair : lowered_partialgraphs)
+ {
+ const auto &partialgraph_index = pair.first;
+ auto &lowered_partialgraph = pair.second;
+ StaticShapeInferer partial_inferer(partialgraph_index, lowered_partialgraphs);
+ auto ordered_ops = lowered_partialgraph->graph().topolSortOperations();
+ for (auto op_ind : ordered_ops)
+ {
+ const auto &op = lowered_partialgraph->graph().operations().at(op_ind);
+ bool has_dynamic_tensor = partial_inferer.infer(op);
+ lowered_partialgraph->setHasDynamicTensor(op_ind, has_dynamic_tensor);
+ }
+ partial_inferer.dump();
+ }
+
+ // Shape validation
+ // TODO Move shape independent feature check from ShapeValidator to OperationValidator
+ // TODO Move ShapeValidator into shape inference
+ // - Check input tensor shape validation
+ // - Check parameter values whose validity depends on the input tensor shape
+ // - Output tensor shape validation is unnecessary because the
+ //   static/dynamic shape inferer will produce valid output shapes
+ for (auto &pair : lowered_partialgraphs)
+ {
+ auto &lowered_partialgraph = pair.second;
+ compiler::ShapeValidator{lowered_partialgraph->graph()}();
+ }
+
+ /*************************************************************
+ * Backend independent analysis & optimization phase finished
+ *************************************************************/
+ std::map<uint32_t, std::unique_ptr<compiler::LoweredGraph>> ordered;
+ for (auto &pair : lowered_partialgraphs)
+ {
+ // const auto &partialgraph_index = pair.first;
+ auto &lowered_partialgraph = pair.second;
+
+ ordered.insert(make_pair(pair.first.value(), std::move(lowered_partialgraph)));
+ }
+
+ for (auto &pair : ordered)
+ {
+ executor_map = std::make_shared<exec::ExecutorMap>();
+ const auto &partialgraph_index = ir::SubgraphIndex(pair.first);
+ auto &lowered_partialgraph = pair.second;
+ auto indexed_ranks = lowered_partialgraph->indexed_ranks();
+ ir::OperationDumper dumper("Executor generation of Subgraph " +
+ std::to_string(partialgraph_index.value()));
+ lowered_partialgraph->graph().operations().iterate(
+ [&](const ir::OperationIndex &, const ir::Operation &op) { op.accept(dumper); });
+ auto executor = std::unique_ptr<exec::IExecutor>{
+ ExecutorFactory::get().create(std::move(lowered_partialgraph), _options, executor_map)};
+ executor->setIndexedRanks(indexed_ranks);
+ executor_map->insert(std::make_pair(ir::SubgraphIndex{0}, std::move(executor)));
+ executors.push_back(executor_map);
+ }
+
+ _subgraphs.reset();
+ /********************************
+ * Code generation phase finished
+ ********************************/
+ _state = State::COMPILED;
+
+ return executors;
+}
+
} // namespace compiler
} // namespace onert
}
}
+LoweredGraph::LoweredGraph(const ir::Graph &parent_graph, const ir::Graph &graph,
+ const CompilerOptions &options)
+ : _graph{graph}, _parent_graph{parent_graph}
+{
+ // set tracing_ctx for copied graph
+ if (options.tracing_ctx)
+ {
+ auto subgraph_index = options.tracing_ctx->getSubgraphIndex(&graph);
+ options.tracing_ctx->setSubgraphIndex(&_graph, subgraph_index.value());
+ }
+
+ // Build backend contexts
+ auto &backend_manager = BackendManager::get();
+ // Create contexts for other backends
+ for (auto backend_str : options.backend_list)
+ {
+ backend_manager.loadBackend(backend_str);
+ auto backend = backend_manager.get(backend_str);
+
+ // TODO The default backend list contains "cpu", "acl_cl" and "acl_neon", but some of them are
+ // not available on x64 and some other platforms. Skipping unloadable backends is a workaround;
+ // we should change it back (throw if a backend fails to load) later.
+ if (!backend)
+ {
+ VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str << std::endl;
+ continue;
+ }
+ }
+ if (backend_manager.num_backends() == 0)
+ throw std::runtime_error{"No available backends loaded."};
+
+ // TODO Move "schedule" phase out of here
+ // Schedule
+ std::unique_ptr<BackendResolver> backend_resolver;
+ auto all_backends = backend_manager.getAll();
+ if (options.he_scheduler)
+ {
+ auto scheduler = HEScheduler(all_backends, options);
+ backend_resolver = scheduler.schedule(_graph);
+ _indexed_ranks = scheduler.getIndexedRanks();
+ }
+ else
+ {
+ auto scheduler = ManualScheduler(all_backends, options);
+ backend_resolver = scheduler.schedule(_graph);
+ }
+
+ makeLowerInfo(*backend_resolver);
+ VERBOSE(LoweredGraph) << "dump before mandatory passes" << std::endl;
+ dumper::text::dumpLoweredGraph(*this);
+
+ // Mandatory passes - kind of legalization(?)
+ pass::PassRunner{}
+ .append(std::make_unique<pass::ConstantInsertionPass>(*this))
+ .append(std::make_unique<pass::ConstantLoweringPass>(*this))
+ .append(std::make_unique<pass::PermutationOperationPass>(*this))
+ .append(std::make_unique<pass::PermutationInsertionPass>(*this))
+ .run();
+
+ dumpLowerInfo();
+
+ // Optimization passes (optional)
+ pass::PassRunner{}.append(std::make_unique<pass::PermutationEliminationPass>(*this)).run();
+
+ VERBOSE(LoweredGraph) << "Dump after all the passes" << std::endl;
+ for (auto operand : _graph.getInputs())
+ VERBOSE(LoweredGraph) << "Graph Input : " << operand << std::endl;
+ for (auto operand : _graph.getOutputs())
+ VERBOSE(LoweredGraph) << "Graph Output : " << operand << std::endl;
+ dumper::text::dumpLoweredGraph(*this);
+
+ // Graph verifications
+ {
+ assert(ir::verifier::InputOutputChecker().verify(_graph));
+ assert(ir::verifier::DAGChecker().verify(_graph));
+ assert(ir::verifier::EdgeChecker().verify(_graph));
+ }
+}
+
void LoweredGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolver)
{
_graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
const auto &primary_subg = primary_subgraph();
_io_desc.inputs.resize(primary_subg.getInputs().size());
_io_desc.outputs.resize(primary_subg.getOutputs().size());
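+ // Binary semaphore guarding concurrent access to the async IO descriptor queue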
+ sem_init(&_async_io_descs_sem, 0, 1);
}
void Execution::changeInputShape(const ir::IOIndex &index, const ir::Shape &new_shape)
_io_desc.inputs.at(index.value()) = std::make_unique<InputDesc>(info, buffer, length, layout);
}
+void Execution::createNewAsyncDesc(uint32_t count)
+{
+ IODescription *_async_io_desc = new IODescription;
+ _async_io_desc->inputs.resize(primary_subgraph().getInputs().size());
+ _async_io_desc->outputs.resize(primary_subgraph().getOutputs().size());
+
+ _async_io_descs.push_back({_async_io_desc, count});
+}
+
+void Execution::setFinish() { finished = true; }
+
+bool Execution::isEmptyQueue()
+{
+ asyncIoDescSemWait();
+ bool ret = _async_io_descs.empty();
+ if (!ret)
+ {
+ for (uint32_t idx = 0; idx < _async_io_descs.front().first->inputs.size(); idx++)
+ {
+ if (_async_io_descs.front().first->inputs.at(idx).get() == nullptr)
+ {
+ ret = true;
+ break;
+ }
+ }
+ }
+ asyncIoDescSemPost();
+ return ret;
+}
+
+void Execution::executeAsyncInput(const ir::IOIndex &index, const void *buffer, size_t length,
+ ir::Layout layout)
+{
+ const auto input_index = primary_subgraph().getInputs().at(index);
+ const auto info = primary_subgraph().operands().at(input_index).info();
+ IODescription *_async_io_desc = _async_io_descs.back().first;
+
+ {
+ auto input_shape_sig = _async_io_desc->dynamic_input_shapes.find(index);
+ auto size_required =
+ (input_shape_sig != _async_io_desc->dynamic_input_shapes.end())
+ ? input_shape_sig->second.num_elements() * onert::ir::sizeOfDataType(info.typeInfo().type())
+ : info.total_size();
+
+ if (length < size_required)
+ {
+ throw std::runtime_error{"Too small length"};
+ }
+ }
+ void *_buffer = (void *)malloc(length);
+ if (_buffer == NULL)
+ {
+ throw std::runtime_error{"malloc failed"};
+ }
+ memcpy(_buffer, buffer, length);
+
+ _async_io_desc->inputs.at(index.value()) =
+ std::make_unique<InputDesc>(info, _buffer, length, layout);
+}
+
+void Execution::executeAsyncOutput(const ir::IOIndex &index, void *buffer, size_t length,
+ ir::Layout layout)
+{
+ const auto output_index = primary_subgraph().getOutputs().at(index);
+ const auto info = primary_subgraph().operands().at(output_index).info();
+ IODescription *_async_io_desc = _async_io_descs.front().first;
+
+ if (length < info.total_size())
+ {
+ throw std::runtime_error{"Too small length"};
+ }
+
+ _async_io_desc->outputs.at(index.value()) =
+ std::make_unique<OutputDesc>(info, buffer, length, layout);
+}
+
// TODO Remove default parameter
void Execution::setInput(const ir::IOIndex &index, const ir::TypeInfo &type, const ir::Shape &shape,
const void *buffer, size_t length, ir::Layout layout)
VERBOSE(Execution) << "Execution finished" << std::endl;
}
+void Execution::AsyncExecute()
+{
+ VERBOSE(Execution) << "Start Async execution" << std::endl;
+ if (_async_io_descs.empty())
+ {
+ VERBOSE(Execution) << "The input is not ready" << std::endl;
+ return;
+ }
+
+ primary_executor()->execute(*_async_io_descs.front().first);
+}
+
void Execution::startExecute()
{
VERBOSE(Execution) << "Create asynchronous execution thread" << std::endl;
return output_desc->info.shape();
}
+void Execution::asyncIoDescSemWait() { sem_wait(&_async_io_descs_sem); }
+
+void Execution::asyncIoDescSemPost() { sem_post(&_async_io_descs_sem); }
+
+void Execution::runInference()
+{
+ uint32_t inference_cnt;
+ uint32_t output_sz = primary_subgraph().getOutputs().size();
+ while (true)
+ {
+ if (isEmptyQueue())
+ {
+ if (isFinished())
+ {
+ if (!next_exes.empty())
+ {
+ for (uint32_t i = 0; i < next_exes.size(); i++)
+ {
+ std::get<0>(next_exes[i])->setFinish();
+ }
+ }
+ else
+ {
+ sholudStop();
+ }
+ break;
+ }
+ }
+ else
+ {
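+      // Allocate an output buffer per model output: element count times element size
+      // (only 4-byte FLOAT32/INT32/UINT32 and 8-byte INT64 element types are handled here).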
+ for (uint32_t i = 0; i < output_sz; i++)
+ {
+ auto opidx = primary_subgraph().getOutputs().at(i);
+ auto shape = primary_subgraph().operands().at(opidx).shape();
+ auto dtype = primary_subgraph().operands().at(opidx).typeInfo().type();
+ auto rank = shape.rank();
+ uint32_t tensor_size = 1;
+ for (int32_t j = 0; j < rank; j++)
+ {
+ tensor_size *= shape.dim(j);
+ }
+ if (dtype == onert::ir::DataType::FLOAT32 || dtype == onert::ir::DataType::INT32 ||
+ dtype == onert::ir::DataType::UINT32)
+ tensor_size *= 4;
+ else if (dtype == onert::ir::DataType::INT64)
+ tensor_size *= 8;
+ void *_buffer = (void *)malloc(tensor_size);
+ if (_buffer == NULL)
+ {
+ throw std::runtime_error{"malloc failed"};
+ }
+ executeAsyncOutput(onert::ir::IOIndex(i), _buffer, tensor_size);
+ }
+ AsyncExecute();
+
+ // set inputs of next execution
+ auto _io_desc = getAsyncIoDescs()->front().first;
+ inference_cnt = getAsyncIoDescs()->front().second;
+ getAsyncIoDescs()->pop_front();
+
+ for (uint32_t i = 0; i < next_exes.size(); i++)
+ {
+ auto next_exe = std::get<0>(next_exes[i]);
+ auto o_index = std::get<1>(next_exes[i]);
+ auto i_index = std::get<2>(next_exes[i]);
+
+ next_exe->asyncIoDescSemWait();
+ auto next_io_descs = next_exe->getAsyncIoDescs();
+ bool exist = false;
+ for (auto iter = next_io_descs->begin(); iter != next_io_descs->end(); iter++)
+ {
+ if (inference_cnt == iter->second)
+ {
+ exist = true;
+ }
+ }
+
+ if (!exist)
+ {
+ next_exe->createNewAsyncDesc(inference_cnt);
+ }
+ for (auto iter = next_io_descs->begin(); iter != next_io_descs->end(); iter++)
+ {
+ if (inference_cnt == iter->second)
+ {
+ const auto input_index = next_exe->primary_subgraph().getInputs().at(i_index.value());
+ const auto info = next_exe->primary_subgraph().operands().at(input_index).info();
+
+ size_t length = _io_desc->outputs[o_index.value()]->size;
+ void *_buffer = (void *)malloc(length);
+ if (_buffer == NULL)
+ {
+ throw std::runtime_error{"malloc failed"};
+ }
+ memcpy(_buffer, _io_desc->outputs[o_index.value()]->buffer, length);
+
+ iter->first->inputs.at(i_index.value()) = std::make_unique<onert::exec::InputDesc>(
+ info, _buffer, length, onert::ir::Layout::NHWC);
+ break;
+ }
+ }
+ next_exe->asyncIoDescSemPost();
+ }
+
+ if (next_exes.empty())
+ {
+ std::vector<void *> results;
+ for (uint32_t i = 0; i < _io_desc->outputs.size(); i++)
+ {
+ size_t length = _io_desc->outputs[i]->size;
+ void *_buffer = (void *)malloc(length);
+ if (_buffer == NULL)
+ {
+ throw std::runtime_error{"malloc failed"};
+ }
+ memcpy(_buffer, _io_desc->outputs[i]->buffer, length);
+ results.push_back(_buffer);
+ }
+ _async_results.push_back(results);
+ }
+
+ for (uint32_t i = 0; i < _io_desc->inputs.size(); i++)
+ {
+ auto p = _io_desc->inputs.at(i).release();
+ if (p)
+ {
+ free((void *)p->buffer);
+ delete p;
+ }
+ }
+ for (uint32_t i = 0; i < _io_desc->outputs.size(); i++)
+ {
+ auto p = _io_desc->outputs.at(i).release();
+ if (p)
+ {
+ free(p->buffer);
+ delete p;
+ }
+ }
+ delete _io_desc;
+ }
+ }
+}
+
+bool Execution::stopWait(void) const { return stop_wait; }
+
+void Execution::sholudStop() { stop_wait = true; }
+
} // namespace exec
} // namespace onert
backend::BackendContexts &&backend_contexts,
const compiler::TensorRegistries &tensor_regs,
const util::TracingCtx *tracing_ctx)
- : _lowered_graph{std::move(lowered_graph)},
- _backend_contexts{std::move(backend_contexts)}, _graph{_lowered_graph->graph()}, _mutex(),
+ : _lowered_graph{std::move(lowered_graph)}, _backend_contexts{std::move(backend_contexts)},
+ _graph{_lowered_graph->graph()}, _parent_graph{_lowered_graph->parent_graph()}, _mutex(),
_tracing_ctx(tracing_ctx)
{
auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) {
const ir::Graph &graph() final { return _graph; }
+ const ir::Graph &parent_graph() final { return _parent_graph; }
+
void execute(const IODescription &desc) final;
void execute(const std::vector<backend::IPortableTensor *> &inputs,
std::unique_ptr<compiler::LoweredGraph> _lowered_graph;
backend::BackendContexts _backend_contexts;
const ir::Graph &_graph;
+ const ir::Graph &_parent_graph;
std::vector<backend::builtin::IOTensor *> _input_tensors;
std::vector<backend::builtin::IOTensor *> _output_tensors;
std::mutex _mutex;
void LinearExecutor::executeImpl()
{
- auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_graph);
-
- _subject.notifySubgraphBegin(profiling_subg_index);
- for (auto &&code : _code)
+ if (_tracing_ctx)
{
- const auto backend = code.lower_info->backend();
+ auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_graph);
+
+ _subject.notifySubgraphBegin(profiling_subg_index);
+ for (auto &&code : _code)
+ {
+ const auto backend = code.lower_info->backend();
// TODO : Move ruy profiler into ExecutionObserver
#ifdef RUY_PROFILER
- ruy::profiler::ScopeLabel label(code.op->name());
+ ruy::profiler::ScopeLabel label(code.op->name());
#endif
- _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);
+ _subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);
+
+ auto &fn_seq = code.fn_seq;
+
+ fn_seq->initRunning();
+
+ bool handle_dynamic_tensor =
+ _lowered_graph->getHasDynamicTensor(code.op_ind) || hasDynamicInput();
+ fn_seq->enableDynamicShapeInferer(handle_dynamic_tensor);
+ fn_seq->run();
- auto &fn_seq = code.fn_seq;
+ _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
+ }
+ _subject.notifySubgraphEnd(profiling_subg_index);
+ }
+ else
+ {
+ for (auto &&code : _code)
+ {
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
+#endif
- fn_seq->initRunning();
+ auto &fn_seq = code.fn_seq;
- bool handle_dynamic_tensor =
- _lowered_graph->getHasDynamicTensor(code.op_ind) || hasDynamicInput();
- fn_seq->enableDynamicShapeInferer(handle_dynamic_tensor);
- fn_seq->run();
+ fn_seq->initRunning();
- _subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
+ bool handle_dynamic_tensor =
+ _lowered_graph->getHasDynamicTensor(code.op_ind) || hasDynamicInput();
+ fn_seq->enableDynamicShapeInferer(handle_dynamic_tensor);
+ fn_seq->run();
+ }
}
- _subject.notifySubgraphEnd(profiling_subg_index);
}
} // namespace exec
* @return Graph object
*/
const ir::Graph &graph() final { return _graph; }
+
+ const ir::Graph &parent_graph() final
+ {
+ throw std::runtime_error{"Interpreter does not support this function."};
+ }
void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>>) override{
// Not implemented
};
{
using ElementwiseBinaryType = onert::ir::operation::ElementwiseBinary::ElementwiseBinaryType;
static const std::unordered_map<ElementwiseBinaryType, std::string> name_map{
+ {ElementwiseBinaryType::FLOOR_DIV, std::string{"FloorDiv"}},
{ElementwiseBinaryType::LOGICAL_AND, std::string{"LogicalAnd"}},
{ElementwiseBinaryType::LOGICAL_OR, std::string{"LogicalOr"}},
{ElementwiseBinaryType::MAX, std::string{"Max"}},
* @param graph reference on subgraphs
*/
explicit BaseLoader(std::unique_ptr<ir::Subgraphs> &subgs)
- : _base{nullptr}, _pagesize(getpagesize()), _fd(-1), _subgraphs(subgs), _model{nullptr}
+ : _base{nullptr}, _pagesize(getpagesize()), _fd(-1), _subgraphs(subgs), _model{nullptr},
+ _tensor_names(std::make_shared<std::unordered_map<ir::OperandIndex, std::string>>())
{
_use_mmaped_data = util::getConfigBool(util::config::USE_MMAPED_DATA);
}
const Model *_model;
// Maps Tensor indices to onert Operands.
std::vector<ir::OperandIndex> _tensor_to_operand;
- std::unordered_map<ir::OperandIndex, std::string> _tensor_names;
+ std::shared_ptr<std::unordered_map<ir::OperandIndex, std::string>> _tensor_names;
// Verifier
std::unique_ptr<Verifier> _verifier;
// Boolean flag to use MMAPED_DATA
subg.setOperandValue(operand_index, std::move(data_obj));
}
- _tensor_names.emplace(operand_index, tensor->name()->str());
+ _tensor_names->emplace(operand_index, tensor->name()->str());
// Variable
if (tensor->is_variable())
case BuiltinOperator::BuiltinOperator_UNPACK:
loadUnpack(op, subg);
return;
+ case BuiltinOperator::BuiltinOperator_FLOOR_DIV:
+ loadElementwiseBinary(op, subg,
+ ir::operation::ElementwiseBinary::ElementwiseBinaryType::FLOOR_DIV);
+ return;
case BuiltinOperator::BuiltinOperator_MINIMUM:
loadElementwiseBinary(op, subg, ir::operation::ElementwiseBinary::ElementwiseBinaryType::MIN);
return;
for (const std::int32_t input_ind : *circle_subg->inputs())
{
subg->addInput(tensorIdxToOperandIdx(input_ind),
- _tensor_names.at(_tensor_to_operand[input_ind]));
+ _tensor_names->at(_tensor_to_operand[input_ind]));
}
// Set outputs
for (const std::int32_t output_ind : *circle_subg->outputs())
{
subg->addOutput(tensorIdxToOperandIdx(output_ind),
- _tensor_names.at(_tensor_to_operand[output_ind]));
+ _tensor_names->at(_tensor_to_operand[output_ind]));
}
// Create operations
for (const auto *op : *circle_subg->operators())
for (const std::int32_t input_ind : *tflite_subg->inputs())
{
subg->addInput(tensorIdxToOperandIdx(input_ind),
- _tensor_names.at(_tensor_to_operand[input_ind]));
+ _tensor_names->at(_tensor_to_operand[input_ind]));
}
// Set outputs
for (const std::int32_t output_ind : *tflite_subg->outputs())
{
subg->addOutput(tensorIdxToOperandIdx(output_ind),
- _tensor_names.at(_tensor_to_operand[output_ind]));
+ _tensor_names->at(_tensor_to_operand[output_ind]));
}
// Create operations
for (const auto *op : *tflite_subg->operators())
loadOperation(op, *subg);
}
+ subg->setTensorName(_tensor_names);
subg->verify();
return subg;
return RUN_ALL_TESTS();
}
+// FIX for onert: disable argument parsing
+#if 0
void checkArgs(int argc, char** argv, int nextArg) {
if (nextArg != argc) {
std::cerr << "Unexpected argument: " << argv[nextArg] << std::endl;
exit(1);
}
}
+#endif
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
+ // FIX for onert: disable argument parsing
+#if 0
if ((argc > 1) && std::isdigit(argv[1][0])) {
allowedPasses = std::stoull(argv[1]);
checkArgs(argc, argv, 2);
} else {
checkArgs(argc, argv, 1);
}
+#endif
#ifndef NNTEST_ONLY_PUBLIC_API
android::nn::initVLogMask();
target_compile_definitions(${RUNTIME_NNFW_API_TEST} PRIVATE TEST_XNNPACK_BACKEND)
endif(Xnnpack_FOUND)
+nnas_find_package(Opencl_Headers QUIET)
+if(Opencl_Headers_FOUND)
+ target_compile_definitions(${RUNTIME_NNFW_API_TEST} PRIVATE TEST_GPU_CL_BACKEND)
+endif(Opencl_Headers_FOUND)
+
set(RUNTIME_NNFW_API_TEST_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/src)
target_include_directories(${RUNTIME_NNFW_API_TEST} PRIVATE ${RUNTIME_NNFW_API_TEST_INCLUDE})
-target_link_libraries(${RUNTIME_NNFW_API_TEST} nnfw-dev)
+target_link_libraries(${RUNTIME_NNFW_API_TEST} nnfw-dev jsoncpp)
target_link_libraries(${RUNTIME_NNFW_API_TEST} gtest gmock)
target_link_libraries(${RUNTIME_NNFW_API_TEST} ${LIB_PTHREAD} dl)
target_link_libraries(${RUNTIME_NNFW_API_TEST} circle_schema)
# Install nnpackage test model (while)
set(NNPACKAGE_MODEL_DIR ${NNAS_PROJECT_SOURCE_DIR}/nnpackage/examples/v1.0.0/while_dynamic)
install(DIRECTORY ${NNPACKAGE_MODEL_DIR} DESTINATION ${NNPACKAGE_INSTALL_TARGET}/while_dynamic)
+
+# Install nnpackage test model (mobilenet)
+set(NNPACKAGE_MODEL_DIR ${NNAS_PROJECT_SOURCE_DIR}/runtime/contrib/TFLiteSharp/TFLiteTestApp/res/)
+set(NNPACKAGE_INSTALL_TARGET unittest_standalone/nnfw_api_gtest_models)
+
+install(DIRECTORY ${NNPACKAGE_MODEL_DIR} DESTINATION ${NNPACKAGE_INSTALL_TARGET}/mobilenet_v1_1.0_224)
+
0);
}
+uint32_t CircleGen::addOperatorFloorDiv(const OperatorParams ¶ms)
+{
+ return addOperatorWithOptions(params, circle::BuiltinOperator_FLOOR_DIV,
+ circle::BuiltinOptions_NONE, 0);
+}
+
uint32_t CircleGen::addOperatorL2Normalization(const OperatorParams ¶ms)
{
auto options = circle::CreateL2NormOptions(_fbb).Union();
return addOperatorWithOptions(params, reduce_op, circle::BuiltinOptions_ReducerOptions, options);
}
+uint32_t CircleGen::addOperatorRelu(const OperatorParams ¶ms)
+{
+ return addOperatorWithOptions(params, circle::BuiltinOperator_RELU, circle::BuiltinOptions_NONE,
+ 0);
+}
+
+uint32_t CircleGen::addOperatorRelu6(const OperatorParams ¶ms)
+{
+ return addOperatorWithOptions(params, circle::BuiltinOperator_RELU6, circle::BuiltinOptions_NONE,
+ 0);
+}
+
uint32_t CircleGen::addOperatorReshape(const OperatorParams ¶ms, const Shape *new_shape)
{
auto options = circle::CreateReshapeOptionsDirect(_fbb, new_shape).Union();
uint32_t addOperatorExpandDims(const OperatorParams ¶ms);
uint32_t addOperatorFill(const OperatorParams ¶ms);
uint32_t addOperatorFloor(const OperatorParams ¶ms);
+ uint32_t addOperatorFloorDiv(const OperatorParams ¶ms);
uint32_t addOperatorFullyConnected(const OperatorParams ¶ms,
circle::FullyConnectedOptionsWeightsFormat weights_format =
circle::FullyConnectedOptionsWeightsFormat_DEFAULT);
* @brief Create circle Reshape op
* the second param new_shape can be optional just like circle::CreateReshapeOptionsDirect
*/
+ uint32_t addOperatorRelu(const OperatorParams ¶ms);
+ uint32_t addOperatorRelu6(const OperatorParams ¶ms);
uint32_t addOperatorReshape(const OperatorParams ¶ms, const Shape *new_shape = nullptr);
uint32_t addOperatorResizeBilinear(const OperatorParams ¶ms, bool align_corners = false,
bool half_pixel_centers = false);
_backends.push_back(backend);
}
#endif
+#ifdef TEST_GPU_CL_BACKEND
+ if (backend == "gpu_cl")
+ {
+ _backends.push_back(backend);
+ }
+#endif
}
}
SUCCEED();
}
+TEST_F(GenModelTest, Reshape_with_shape_param_as_const_float)
+{
+ CircleGen cgen;
+ auto f32 = circle::TensorType::TensorType_FLOAT32;
+ int input = cgen.addTensor({{4}, f32});
+
+ std::vector<int32_t> new_shape_data{2, 2}; // const of value [2, 2]
+ uint32_t new_shape_buf = cgen.addBuffer(new_shape_data);
+ int new_shape = cgen.addTensor({{2}, f32, new_shape_buf});
+ int out = cgen.addTensor({{2, 2}, f32});
+
+ // reshape with new_shape param
+ cgen.addOperatorReshape({{input, new_shape}, {out}}, &new_shape_data);
+ cgen.setInputsAndOutputs({input}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(uniformTCD<float>({{1, 2, 3, 4}}, {{1, 2, 3, 4}}));
+ _context->setBackends({"gpu_cl"});
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, neg_Reshape_with_shape_param_as_const_float)
+{
+ // We check whether Reshape with a shape param generates an error during compilation when the param is wrong
+ CircleGen cgen;
+ auto f32 = circle::TensorType::TensorType_FLOAT32;
+
+ int input = cgen.addTensor({{4}, f32});
+
+ std::vector<int32_t> wrong_new_shape_data{2, 3}; // not match with input shape
+ uint32_t new_shape_buf = cgen.addBuffer(wrong_new_shape_data);
+ int new_shape = cgen.addTensor({{2}, f32, new_shape_buf});
+
+ int out = cgen.addTensor({{2, 2}, f32});
+
+ cgen.addOperatorReshape({{input, new_shape}, {out}}, &wrong_new_shape_data);
+ cgen.setInputsAndOutputs({input}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(uniformTCD<float>({{1, 2, 3, 4}}, {{1, 2, 3, 4}}));
+ _context->setBackends({"gpu_cl"});
+
+ _context->expectFailCompile();
+
+ SUCCEED();
+}
+
TEST_F(GenModelTest, Reshape_without_shape_param)
{
CircleGen cgen;
{
return _base_path + "/nnfw_api_gtest_models/" + package_name + "/" + package_name;
}
+
+std::string NNPackages::getModelAbsoluteFilePath(const char *package_name)
+{
+ return _base_path + "/nnfw_api_gtest_models/" + package_name + "/" + package_name + ".tflite";
+}
std::string getModelAbsolutePath(const char *package_name);
/**
+ * @brief Get the absolute path of the model file to find
+ *
+ * @param package_name Package name
+ * @return std::string The absolute path of model file
+ */
+ std::string getModelAbsoluteFilePath(const char *package_name);
+
+ /**
* @brief Save the current executable's directory based on argv[0] and CWD
*
* @param argv0 0th command line argument of the current process
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "fixtures.h"
+#include "common.h"
+#include <fstream>
+#include <stdio.h>
+#include <json/json.h>
+#include <thread>
+
+void build_partition_map()
+{
+ Json::Value root;
+ Json::Value graphs(Json::arrayValue);
+ int num = 31;
+
+ for (int i = 0; i < num; i++)
+ {
+ if (i < 7)
+ graphs.append(Json::Value(0));
+ else
+ graphs.append(Json::Value(1));
+ }
+
+ root["partition_map"] = graphs;
+ root["num_partitions"] = 2;
+
+ Json::StyledWriter sw;
+ std::string jsonString = sw.write(root);
+
+ FILE *pFile = NULL;
+
+ pFile = fopen("./partition_map.json", "wt");
+ fwrite(jsonString.c_str(), jsonString.length(), 1, pFile);
+ fclose(pFile);
+}
+
+TEST_F(ValidationTestPipelineSession, create_pipeline_001)
+{
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+ SUCCEED();
+}
+
+TEST_F(ValidationTestPipelineSession, pipeline_session_test_model)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_set_available_backends(_session, "cpu"));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+ SUCCEED();
+}
+
+TEST_F(ValidationTestPipelineSession, prepare_pipeline_001)
+{
+ std::ifstream readFile("./partition_map.json");
+
+ if (readFile.good())
+ {
+ remove("./partition_map.json");
+ }
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ ASSERT_EQ(nnfw_prepare_pipeline(_session, "./partition_map.json"), NNFW_STATUS_ERROR);
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+}
+
+TEST_F(ValidationTestPipelineSession, prepare_pipeline_002)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+
+ SUCCEED();
+}
+
+TEST_F(ValidationTestPipelineSession, input_tensorinfo_pipeline)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ nnfw_tensorinfo t_input;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ NNFW_ENSURE_SUCCESS(nnfw_input_tensorinfo(_session, 0, &t_input));
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+
+ SUCCEED();
+}
+
+TEST_F(ValidationTestPipelineSession, output_tensorinfo_pipeline)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ nnfw_tensorinfo t_output;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ NNFW_ENSURE_SUCCESS(nnfw_output_tensorinfo(_session, 0, &t_output));
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+
+ SUCCEED();
+}
+
+TEST_F(ValidationTestPipelineSession, input_size_pipeline)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ uint32_t input_num = -1;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ NNFW_ENSURE_SUCCESS(nnfw_input_size(_session, &input_num));
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ ASSERT_EQ(input_num, 1);
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, output_size_pipeline)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ uint32_t output_num = -1;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ NNFW_ENSURE_SUCCESS(nnfw_output_size(_session, &output_num));
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ ASSERT_EQ(output_num, 1);
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, set_input_tensorinfo_pipeline)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ nnfw_tensorinfo t_input_original;
+ nnfw_tensorinfo t_input_after;
+ nnfw_tensorinfo t_input = {NNFW_TYPE_TENSOR_FLOAT32, 4, {1, 224, 224, 3}};
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ NNFW_ENSURE_SUCCESS(nnfw_input_tensorinfo(_session, 0, &t_input_original));
+ NNFW_ENSURE_SUCCESS(nnfw_set_input_tensorinfo(_session, 0, &t_input));
+ NNFW_ENSURE_SUCCESS(nnfw_input_tensorinfo(_session, 0, &t_input_after));
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ ASSERT_TRUE(tensorInfoEqual(t_input_original, t_input_after));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, input_output_tensorindex)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ uint32_t input_index = 100;
+ NNFW_ENSURE_SUCCESS(nnfw_input_tensorindex(_session, "input", &input_index));
+ ASSERT_EQ(input_index, 0);
+
+ uint32_t output_index = 100;
+ NNFW_ENSURE_SUCCESS(
+ nnfw_output_tensorindex(_session, "MobilenetV1/Predictions/Reshape_1", &output_index));
+ ASSERT_EQ(output_index, 0);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_create_pipeline_001)
+{
+ ASSERT_EQ(nnfw_create_session(nullptr), NNFW_STATUS_UNEXPECTED_NULL);
+}
+
+TEST_F(ValidationTestPipelineSession, neg_pipeline_session_model_load)
+{
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ ASSERT_EQ(nnfw_load_model_from_modelfile(
+ nullptr, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()),
+ NNFW_STATUS_UNEXPECTED_NULL);
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+}
+
+TEST_F(ValidationTestPipelineSession, neg_prepare_pipeline_001)
+{
+ ASSERT_EQ(nnfw_prepare_pipeline(nullptr, nullptr), NNFW_STATUS_UNEXPECTED_NULL);
+}
+
+TEST_F(ValidationTestPipelineSession, neg_set_in_pipeline)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ float input_buf[1 * 224 * 224 * 3];
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ ASSERT_EQ(nnfw_set_input(_session, 0, NNFW_TYPE_TENSOR_FLOAT32, input_buf, sizeof(input_buf)),
+ NNFW_STATUS_ERROR);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_set_out_pipeline)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ float output_buf[1 * 1001];
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ ASSERT_EQ(nnfw_set_output(_session, 0, NNFW_TYPE_TENSOR_FLOAT32, output_buf, sizeof(output_buf)),
+ NNFW_STATUS_ERROR);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_input_tensorinfo_pipeline_001)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ nnfw_tensorinfo t_input;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ ASSERT_EQ(nnfw_input_tensorinfo(nullptr, 0, &t_input), NNFW_STATUS_UNEXPECTED_NULL);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_input_tensorinfo_pipeline_002)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ nnfw_tensorinfo t_input;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ ASSERT_EQ(nnfw_input_tensorinfo(_session, 1, &t_input), NNFW_STATUS_ERROR);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_output_tensorinfo_pipeline_001)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ nnfw_tensorinfo t_output;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ ASSERT_EQ(nnfw_output_tensorinfo(nullptr, 0, &t_output), NNFW_STATUS_UNEXPECTED_NULL);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_output_tensorinfo_pipeline_002)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ nnfw_tensorinfo t_output;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ ASSERT_EQ(nnfw_output_tensorinfo(_session, 1, &t_output), NNFW_STATUS_ERROR);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_input_output_size_pipeline)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ uint32_t input_num = -1;
+ uint32_t output_num = -1;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ ASSERT_EQ(nnfw_input_size(nullptr, &input_num), NNFW_STATUS_UNEXPECTED_NULL);
+ ASSERT_EQ(input_num, -1);
+ ASSERT_EQ(nnfw_output_size(nullptr, &output_num), NNFW_STATUS_UNEXPECTED_NULL);
+ ASSERT_EQ(output_num, -1);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_set_input_tensorinfo_pipeline)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+ nnfw_tensorinfo t_input = {NNFW_TYPE_TENSOR_FLOAT32, 4, {1, 224, 224, 3}};
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ ASSERT_EQ(nnfw_set_input_tensorinfo(nullptr, 0, &t_input), NNFW_STATUS_UNEXPECTED_NULL);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_input_output_tensorindex)
+{
+ std::vector<void *> dummy1;
+ std::vector<uint32_t> dummy2;
+
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ uint32_t input_index = 100;
+ ASSERT_EQ(nnfw_input_tensorindex(_session, "input1", &input_index), NNFW_STATUS_ERROR);
+ ASSERT_EQ(input_index, 100);
+
+ uint32_t output_index = 100;
+ ASSERT_EQ(nnfw_output_tensorindex(_session, "MobilenetV1/Predictions/Reshape_2", &output_index),
+ NNFW_STATUS_ERROR);
+ ASSERT_EQ(output_index, 100);
+
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, &dummy1, &dummy2));
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, neg_run_pipeline)
+{
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_set_available_backends(_session, "cpu"));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ auto producer = [this]() {
+ std::vector<void *> inputs;
+ std::vector<uint32_t> lengths;
+ inputs.clear();
+ lengths.clear();
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, (void *)&inputs, (void *)&lengths));
+ };
+
+ auto consumer = [this]() {
+ std::vector<void *> outputs;
+ ASSERT_EQ(nnfw_pop_pipeline_output(_session, (void *)&outputs), NNFW_STATUS_ERROR);
+ };
+
+ auto producer_thread = std::thread(producer);
+ auto consumer_thread = std::thread(consumer);
+
+ producer_thread.join();
+ consumer_thread.join();
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+}
+
+TEST_F(ValidationTestPipelineSession, run_pipeline)
+{
+ build_partition_map();
+
+ NNFW_ENSURE_SUCCESS(nnfw_create_session(&_session));
+ NNFW_ENSURE_SUCCESS(nnfw_load_model_from_modelfile(
+ _session, NNPackages::get().getModelAbsoluteFilePath("mobilenet_v1_1.0_224").c_str()));
+ NNFW_ENSURE_SUCCESS(nnfw_set_available_backends(_session, "cpu"));
+ NNFW_ENSURE_SUCCESS(nnfw_prepare_pipeline(_session, "./partition_map.json"));
+
+ auto producer = [this]() {
+ std::vector<void *> inputs;
+ std::vector<uint32_t> lengths;
+ inputs.clear();
+ lengths.clear();
+ NNFW_ENSURE_SUCCESS(nnfw_push_pipeline_input(_session, (void *)&inputs, (void *)&lengths));
+ };
+
+ auto producer_thread = std::thread(producer);
+
+ producer_thread.join();
+ NNFW_ENSURE_SUCCESS(nnfw_close_session(_session));
+
+ remove("./partition_map.json");
+
+ SUCCEED();
+}
}
};
+class ValidationTestPipelineSession : public ValidationTest
+{
+protected:
+ nnfw_session *_session = nullptr;
+};
+
#endif // __NNFW_API_TEST_FIXTURES_H__
_context = std::make_unique<GenModelTestContext>(cgen.finish());
_context->addTestCase(uniformTCD<float>({{1, 3, 2, 4}}, {{6, 7, 9, 8}}));
_context->addTestCase(uniformTCD<float>({{0, 1, 2, 3}}, {{5, 5, 9, 7}}));
- _context->setBackends({"acl_cl", "acl_neon", "cpu"});
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "gpu_cl"});
SUCCEED();
}
_context = std::make_unique<GenModelTestContext>(cgen.finish());
_context->addTestCase(uniformTCD<float>({{1, 3, 2, 4}, {5, 4, 7, 4}}, {{6, 7, 9, 8}}));
- _context->setBackends({"acl_cl", "acl_neon", "cpu"});
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "gpu_cl"});
SUCCEED();
}
_context = std::make_unique<GenModelTestContext>(cgen.finish());
_context->addTestCase(uniformTCD<float>({{1, 3, 2, 4}}, {{2, 6, 4, 8}}));
- _context->setBackends({"acl_cl", "acl_neon", "cpu"});
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "gpu_cl"});
SUCCEED();
}
float scale;
int64_t zero_point;
} type = {circle::TensorType::TensorType_FLOAT32, 0.0f, 0};
- std::vector<std::string> backend = {"acl_cl", "acl_neon", "cpu"};
+ std::vector<std::string> backend = {"acl_cl", "acl_neon", "cpu", "gpu_cl"};
};
class AveragePool2DVariation : public GenModelTest,
{1, 2, 2, 1},
{1, 1, 1, 1},
{2, 2, 2, 2},
- {circle::TensorType::TensorType_UINT8, 1.2, 3}},
+ {circle::TensorType::TensorType_UINT8, 1.2, 3},
+ {"acl_cl", "acl_neon", "cpu"}},
// uint8_t data -large
AvgPool2DParam{
uniformTCD<uint8_t>({{std::vector<uint8_t>(18 * 36 * 2, 99)}}, {{99, 99, 99, 99}}),
{1, 18, 36, 2},
{1, 1, 2, 2},
{18, 18, 18, 18},
- {circle::TensorType::TensorType_UINT8, 1.2, 3}},
+ {circle::TensorType::TensorType_UINT8, 1.2, 3},
+ {"acl_cl", "acl_neon", "cpu"}},
// int8_t data
// TODO enable acl-cl, acl-neon backend
AvgPool2DParam{uniformTCD<int8_t>({{2, -6, 4, -8}}, {{-2}}),
cgen.setInputsAndOutputs({in}, {out});
_context = std::make_unique<GenModelTestContext>(cgen.finish());
- _context->setBackends({"acl_cl", "acl_neon", "cpu"});
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "gpu_cl"});
_context->expectFailCompile();
SUCCEED();
cgen.setInputsAndOutputs({in}, {out});
_context = std::make_unique<GenModelTestContext>(cgen.finish());
- _context->setBackends({"acl_cl", "acl_neon", "cpu"});
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "gpu_cl"});
_context->expectFailCompile();
SUCCEED();
_context->addTestCase(uniformTCD<float>(
{{4, 0, -5, 1, 0, 4, -1, 1, -1, -3, 3, -2, -4, 1, -2, 2, 4, -4, 2, 2, 0, 4, -1, -2, 4}},
{{47, -4, -25, 9, 10, 10, -13, 11, -14, -26, -12, 26, 20, 40, 1, 3, 11, 4}}));
- _context->setBackends({"acl_cl", "acl_neon", "cpu", "ruy", "xnnpack"});
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "ruy", "xnnpack", "gpu_cl"});
SUCCEED();
}
SUCCEED();
}
+TEST_F(GenModelTest, OneOp_DepthwiseConv2D_No_Multiplier)
+{
+ CircleGen cgen;
+ std::vector<float> weight_data{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+ uint32_t weight_buf = cgen.addBuffer(weight_data);
+ std::vector<float> bias_data{0.5f, -0.5f};
+ uint32_t bias_buf = cgen.addBuffer(bias_data);
+ int in = cgen.addTensor({{1, 2, 2, 2}, circle::TensorType::TensorType_FLOAT32});
+ int weight = cgen.addTensor({{1, 3, 1, 2}, circle::TensorType::TensorType_FLOAT32, weight_buf});
+ int bias = cgen.addTensor({{1, 1, 1, 2}, circle::TensorType::TensorType_FLOAT32, bias_buf});
+ int out = cgen.addTensor({{1, 2, 2, 2}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorDepthwiseConv2D({{in, weight, bias}, {out}}, circle::Padding_SAME, 1, 1, 1,
+ circle::ActivationFunctionType_NONE);
+ cgen.setInputsAndOutputs({in}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(
+ uniformTCD<float>({{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}},
+ {{16.5f, 27.5f, 28.5f, 43.5f, 8.5f, 15.5f, 12.5f, 23.5f}}));
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "gpu_cl"});
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, OneOp_DepthwiseConv2D_No_Multiplier_RELU6)
+{
+ CircleGen cgen;
+ std::vector<float> weight_data{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+ uint32_t weight_buf = cgen.addBuffer(weight_data);
+ std::vector<float> bias_data{0.5f, -0.5f};
+ uint32_t bias_buf = cgen.addBuffer(bias_data);
+ int in = cgen.addTensor({{1, 2, 2, 2}, circle::TensorType::TensorType_FLOAT32});
+ int weight = cgen.addTensor({{1, 3, 1, 2}, circle::TensorType::TensorType_FLOAT32, weight_buf});
+ int bias = cgen.addTensor({{1, 1, 1, 2}, circle::TensorType::TensorType_FLOAT32, bias_buf});
+ int out = cgen.addTensor({{1, 2, 2, 2}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorDepthwiseConv2D({{in, weight, bias}, {out}}, circle::Padding_SAME, 1, 1, 1,
+ circle::ActivationFunctionType_RELU6);
+ cgen.setInputsAndOutputs({in}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(uniformTCD<float>({{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}},
+ {{6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f}}));
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "gpu_cl"});
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, OneOp_DepthwiseConv2D_3x3)
+{
+ CircleGen cgen;
+ std::vector<float> weight_data{0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f,
+ 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f};
+ uint32_t weight_buf = cgen.addBuffer(weight_data);
+ std::vector<float> bias_data{0.0f, 0.0f};
+ uint32_t bias_buf = cgen.addBuffer(bias_data);
+ int in = cgen.addTensor({{1, 2, 2, 2}, circle::TensorType::TensorType_FLOAT32});
+ int weight = cgen.addTensor({{1, 3, 3, 2}, circle::TensorType::TensorType_FLOAT32, weight_buf});
+ int bias = cgen.addTensor({{1, 1, 1, 2}, circle::TensorType::TensorType_FLOAT32, bias_buf});
+ int out = cgen.addTensor({{1, 2, 2, 2}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorDepthwiseConv2D({{in, weight, bias}, {out}}, circle::Padding_SAME, 1, 1, 1,
+ circle::ActivationFunctionType_NONE);
+ cgen.setInputsAndOutputs({in}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(
+ uniformTCD<float>({{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}},
+ {{6.0f, 16.0f, 8.0f, 16.0f, 10.0f, 16.0f, 12.0f, 16.0f}}));
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "gpu_cl"});
+ SUCCEED();
+}
+
TEST_F(GenModelTest, OneOp_DepthwiseConv2D_Dilation)
{
CircleGen cgen;
_context->addTestCase(uniformTCD<float>({{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
{{4, 0, 3, 0, 0, 0, 2, 0, 1}}));
- _context->setBackends({"acl_cl", "acl_neon", "cpu", "xnnpack"});
+ _context->setBackends({"acl_cl", "acl_neon", "cpu", "xnnpack", "gpu_cl"});
SUCCEED();
}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GenModelTest.h"
+
+#include <memory>
+
+TEST_F(GenModelTest, OneOp_FloorDiv_VarToVar_Float)
+{
+ CircleGen cgen;
+ int lhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ int rhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ int out = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorFloorDiv({{lhs, rhs}, {out}});
+ cgen.setInputsAndOutputs({lhs, rhs}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(
+ uniformTCD<float>({{1.0, 2.0, -6.8, 24.2}, {1.0, 2.0, 3.0, 4.0}}, {{1.0, 1.0, -3.0, 6.0}}));
+ _context->setBackends({"cpu"});
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, OneOp_FloorDiv_VarToVar_Float_Broadcast)
+{
+ CircleGen cgen;
+ int lhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ int rhs = cgen.addTensor({{1}, circle::TensorType::TensorType_FLOAT32});
+ int out = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorFloorDiv({{lhs, rhs}, {out}});
+ cgen.setInputsAndOutputs({lhs, rhs}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(uniformTCD<float>({{1.0, 2.0, -6.8, 24.2}, {2.0}}, {{0.0, 1.0, -4, 12.0}}));
+ _context->setBackends({"cpu"});
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, neg_OneOp_FloorDiv_VarToVar_InvalidDivisor1)
+{
+ CircleGen cgen;
+ int lhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ int rhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ int out = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorFloorDiv({{lhs, rhs}, {out}});
+ cgen.setInputsAndOutputs({lhs, rhs}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->setBackends({"cpu"});
+ _context->expectFailModelLoad();
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, neg_OneOp_FloorDiv_Broadcast_InvalidDivisor1)
+{
+ CircleGen cgen;
+ int lhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ int rhs = cgen.addTensor({{1}, circle::TensorType::TensorType_INT32});
+ int out = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorFloorDiv({{lhs, rhs}, {out}});
+ cgen.setInputsAndOutputs({lhs, rhs}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->setBackends({"cpu"});
+ _context->expectFailModelLoad();
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, OneOp_FloorDiv_VarToVar_Int)
+{
+ CircleGen cgen;
+ int lhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ int rhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ int out = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ cgen.addOperatorFloorDiv({{lhs, rhs}, {out}});
+ cgen.setInputsAndOutputs({lhs, rhs}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(
+ uniformTCD<int32_t>({{10, 20, -68, 242}, {1, 2, 3, 4}}, {{10, 10, -23, 60}}));
+ _context->setBackends({"cpu"});
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, OneOp_FloorDiv_VarToVar_Int_Broadcast)
+{
+ CircleGen cgen;
+ int lhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ int rhs = cgen.addTensor({{1}, circle::TensorType::TensorType_INT32});
+ int out = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ cgen.addOperatorFloorDiv({{lhs, rhs}, {out}});
+ cgen.setInputsAndOutputs({lhs, rhs}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(uniformTCD<int32_t>({{10, 20, -67, 242}, {2}}, {{5, 10, -34, 121}}));
+ _context->setBackends({"cpu"});
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, neg_OneOp_FloorDiv_VarToVar_InvalidDivisor2)
+{
+ CircleGen cgen;
+ int lhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ int rhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_FLOAT32});
+ int out = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ cgen.addOperatorFloorDiv({{lhs, rhs}, {out}});
+ cgen.setInputsAndOutputs({lhs, rhs}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->setBackends({"cpu"});
+ _context->expectFailModelLoad();
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, neg_OneOp_FloorDiv_Broadcast_InvalidDivisor2)
+{
+ CircleGen cgen;
+ int lhs = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ int rhs = cgen.addTensor({{1}, circle::TensorType::TensorType_FLOAT32});
+ int out = cgen.addTensor({{1, 2, 2, 1}, circle::TensorType::TensorType_INT32});
+ cgen.addOperatorFloorDiv({{lhs, rhs}, {out}});
+ cgen.setInputsAndOutputs({lhs, rhs}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->setBackends({"cpu"});
+ _context->expectFailModelLoad();
+
+ SUCCEED();
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GenModelTest.h"
+
+TEST_F(GenModelTest, OneOp_Relu)
+{
+ CircleGen cgen;
+ int in = cgen.addTensor({{2, 3}, circle::TensorType::TensorType_FLOAT32});
+ int out = cgen.addTensor({{2, 3}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorRelu({{in}, {out}});
+ cgen.setInputsAndOutputs({in}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(
+ uniformTCD<float>({{0, 1.0, 3.0, 1.0, -1.0, -2.0f}}, {{0, 1.0, 3.0, 1.0, 0, 0}}));
+ _context->setBackends({"cpu", "gpu_cl"});
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, neg_OneOp_Relu_InvalidType)
+{
+ CircleGen cgen;
+ int in = cgen.addTensor({{2, 3}, circle::TensorType::TensorType_UINT8});
+ int out = cgen.addTensor({{2, 3}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorRelu({{in}, {out}});
+ cgen.setInputsAndOutputs({in}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->setBackends({"cpu", "gpu_cl"});
+ _context->expectFailModelLoad();
+
+ SUCCEED();
+}
--- /dev/null
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GenModelTest.h"
+
+TEST_F(GenModelTest, OneOp_Relu6)
+{
+ CircleGen cgen;
+ int in = cgen.addTensor({{2, 3}, circle::TensorType::TensorType_FLOAT32});
+ int out = cgen.addTensor({{2, 3}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorRelu6({{in}, {out}});
+ cgen.setInputsAndOutputs({in}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(
+ uniformTCD<float>({{4, 7.0, 3.0, 8.0, -1.0, -2.0f}}, {{4, 6.0, 3.0, 6.0, 0, 0}}));
+ _context->setBackends({"cpu", "gpu_cl"});
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, neg_OneOp_Relu6_InvalidType)
+{
+ CircleGen cgen;
+ int in = cgen.addTensor({{2, 3}, circle::TensorType::TensorType_UINT8});
+ int out = cgen.addTensor({{2, 3}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorRelu6({{in}, {out}});
+ cgen.setInputsAndOutputs({in}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->setBackends({"cpu", "gpu_cl"});
+ _context->expectFailModelLoad();
+
+ SUCCEED();
+}
SUCCEED();
}
+TEST_F(GenModelTest, neg_OneOp_Softmax_Invalid_Beta)
+{
+ CircleGen cgen;
+ int input = cgen.addTensor({{4, 1, 1, 1}, circle::TensorType::TensorType_FLOAT32});
+ int out = cgen.addTensor({{4, 1, 1, 1}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorSoftmax({{input}, {out}}, 0.1);
+ cgen.setInputsAndOutputs({input}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(uniformTCD<float>({{-1., 0., 1., 1.}}, {{-1., -1., -1., -1.}}));
+ _context->setBackends({"gpu_cl"});
+ _context->expectFailCompile();
+
+ SUCCEED();
+}
+
+TEST_F(GenModelTest, OneOp_Softmax)
+{
+ CircleGen cgen;
+ int lhs = cgen.addTensor({{1, 1, 1, 4}, circle::TensorType::TensorType_FLOAT32});
+ int out = cgen.addTensor({{1, 1, 1, 4}, circle::TensorType::TensorType_FLOAT32});
+ cgen.addOperatorSoftmax({{lhs}, {out}}, 1.0);
+ cgen.setInputsAndOutputs({lhs}, {out});
+
+ _context = std::make_unique<GenModelTestContext>(cgen.finish());
+ _context->addTestCase(uniformTCD<float>(
+ {{-1., 0., 1., 1.}},
+ {{0.054064586758613586, 0.14696279168128967, 0.39948627352714539, 0.39948627352714539}}));
+ _context->setBackends({"acl_cl", "cpu", "gpu_cl"});
+
+ SUCCEED();
+}
+
// Test with different value type
INSTANTIATE_TEST_CASE_P(
GenModelTest, SoftmaxVariation,
cast
concat
conv_2d
-custom
depthwise_conv_2d
div
embedding_lookup
cast
concat
conv_2d
-custom
depthwise_conv_2d
div
embedding_lookup
+++ /dev/null
-MODELFILE_NAME="custom_squared_diff_test.tflite"
* limitations under the License.
*/
-#include "tflite/ext/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tflite/interp/FlatBufferBuilder.h"
#include <tflite/Assert.h>
#include <tflite/InterpreterSession.h>
-#include <tflite/ext/kernels/register.h>
+#include <tflite/interp/FlatBufferBuilder.h>
#include <iostream>
#include <fstream>
// Read tflite model
StderrReporter error_reporter;
auto model = FlatBufferModel::BuildFromFile(tflite_file.c_str(), &error_reporter);
-
- BuiltinOpResolver resolver;
- InterpreterBuilder builder(*model, resolver);
+ auto builder = FlatBufferBuilder(*model);
std::unique_ptr<Interpreter> interpreter;
try
{
- TFLITE_ENSURE(builder(&interpreter));
+ interpreter = builder.build();
}
catch (const std::exception &e)
{
* limitations under the License.
*/
-#include "tflite/ext/kernels/register.h"
+#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "args.h"
throw std::runtime_error{"Cannot create model"};
}
- BuiltinOpResolver resolver;
+ tflite::ops::builtin::BuiltinOpResolver resolver;
InterpreterBuilder builder(*model, resolver);
TFLITE_ENSURE(builder(&interpreter))
interpreter->SetNumThreads(nnfw::misc::EnvVar("THREAD").asInt(1));
--- /dev/null
+#! /usr/bin/python
+import graph_analysis
+import json
+import logging
+import runtime_stats
+import os
+import sys
+import numpy as np
+from queue import LifoQueue
+
+
+class ModelInfo:
+ def __init__(self, modelfile, vertex_weights):
+ self._model_dir = os.path.dirname(modelfile)
+ self._dag = graph_analysis.generate_dag(modelfile)
+ self._ops = graph_analysis.get_model_ops(modelfile)
+ self._tensors = graph_analysis.get_model_tensors(modelfile)
+ self._vertex_weights = vertex_weights
+ """Return Directed Acyclic Graph (DAG)
+ """
+
+ def get_dag(self):
+ return self._dag
+ """Return list of model operations
+ """
+
+ def get_ops(self):
+ return self._ops
+ """Return list of model tensors
+ """
+
+ def get_tensors(self):
+ return self._tensors
+ """Return vertex weights representing execution times
+ of model operations
+ """
+
+ def get_vertex_weights(self):
+ return self._vertex_weights
+ """Return size (bytes) of tensor connecting operation indexes n1 and n2
+ """
+
+ def get_tensor_size(self, n1, n2):
+ tensor_id = set(self._ops[n1]['outputs']).intersection(
+ set(self._ops[n2]['inputs']))
+ assert (len(tensor_id) == 1)
+ idx = tensor_id.pop()
+ tensor = self._tensors[idx]['shape']
+ return np.prod(tensor) * tensor.itemsize
+
+ def get_model_path(self):
+ return self._model_dir
+
+
+class GraphPartition:
+ def __init__(self, K):
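+ # K is the number of partitions (sessions) that the model graph will be split into.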
+ self._K = K
+ self._indx = np.zeros(K, dtype=int)
+ self._session_weights = np.zeros(K, dtype=int)
+ self._session_ids = []
+ logging.basicConfig(level=logging.DEBUG)
+ self._logger = logging.getLogger("Minmax")
+
+ def set_dbglevel(self, dbglevel):
+ logging.basicConfig(level=dbglevel)
+ self._logger = logging.getLogger("Minmax")
+ self._logger.setLevel(dbglevel)
+ """Generates a session graph out of the provided dag (Directed Acyclic Graph)
+ Each dag node is associated with a session id, stored under attribute _session_ids.
+ """
+
+ def generate_session_graph(self):
+ def get_session_ids(i, j):
+ cnt = 0
+ idx_i = -1
+ idx_j = -1
+ for idx in range(self._K):
+ if i in self._session_ids[idx]:
+ idx_i = idx
+ cnt += 1
+ if j in self._session_ids[idx]:
+ idx_j = idx
+ cnt += 1
+ if cnt == 2:
+ break
+ return (idx_i, idx_j)
+
+ dag = self._modelObj.get_dag()
+ n = dag.shape[0]
+ self._session_graph = np.zeros((self._K, self._K), dtype=int)
+ for i in range(n - 1):
+ for j in range(i + 1, n):
+ if dag[i][j] == 1:
+ idx1, idx2 = get_session_ids(i, j)
+ if idx1 == -1 or idx2 == -1:
+ self._logger.debug("Something wrong with session ids")
+ self._logger.debug(self._session_ids)
+ self._logger.debug("%d, %d", i, j)
+ sys.exit(-1)
+ if idx1 != idx2:
+ self._session_graph[idx1][idx2] = 1
+
+ for i in range(self._K - 1):
+ for j in range(i + 1, self._K):
+ if self._session_graph[i][j] == 1 and self._session_graph[j][i] == 1:
+ self._logger.error("Session graph has cycles (%d, %d)", i, j)
+ self._logger.error("Session %d: %s", i, self._session_ids[i])
+ self._logger.error("Session %d: %s", j, self._session_ids[j])
+ sys.exit(-1)
+ """Generate an initial partition of the topological ordering T, with the
+ help of provided vertex weights. This method will update _session_weights, that is,
+ the cumulative sum of vertex weights within a session/partition
+ """
+
+ def initial_partition(self, modelObj, T):
+ self._modelObj = modelObj
+ self._logger.debug("Topological order: %s", T)
+ vwgt = modelObj.get_vertex_weights()
+ sorted_vwgt = np.array([vwgt[i] for i in T])
+ self._logger.debug("sorted weights: %s", sorted_vwgt)
+ sum1 = 0
+ c_sorted_vw = []
+ for s in sorted_vwgt:
+ c_sorted_vw.append(sum1 + s)
+ sum1 += s
+
+ pivot = np.zeros(self._K - 1)
+ self._logger.debug("Cumulative sum weights: %s", c_sorted_vw)
+ for i in range(1, self._K):
+ pivot[i - 1] = round(i * c_sorted_vw[-1] / self._K)
+
+ for i in range(self._K - 1):
+ self._indx[i + 1] = np.argmin(abs(c_sorted_vw - pivot[i]))
+
+ sum_weights = []
+ for i in range(self._K):
+ if i == self._K - 1:
+ self._session_ids.append(np.array(T[self._indx[i]:]))
+ self._session_weights[i] = np.sum(sorted_vwgt[self._indx[i]:])
+ else:
+ self._session_ids.append(np.array(T[self._indx[i]:self._indx[i + 1]]))
+ self._session_weights[i] = np.sum(
+ sorted_vwgt[self._indx[i]:self._indx[i + 1]])
+ self.generate_session_graph()
+ """Print a summary that includes session graph, paritition info comprising node ids and their
+ cumulative vertex weights
+ """
+
+ def summarize(self, T):
+ self._logger.info(
+ "Session Graph:\n%s",
+ np.array2string(
+ self._session_graph, formatter={'int': lambda x: '{:>3}'.format(x)}))
+ for i in range(self._K):
+ self._logger.info("Partition %d : %s, sum weight = %s", i,
+ self._session_ids[i].tolist(), self._session_weights[i])
+ """Move nodes from session1 to session2 until the maximum of the two cumulative session weights are exceeded.
+ The parameters to this method include the session ids, list of vertex weights (vwgt), and the directed adjacency matrix (dag).
+ At the end of the move, the session ids per session, and the session weights are updated. As node movement may also affect the session
+ graph, the session graph is updated as well.
+ """
+
+ def move_nodes(self, session1, session2):
+ dag = self._modelObj.get_dag()
+ vwgt = self._modelObj.get_vertex_weights()
+
+ def session_edges(s1, s2, dag, forward_direction):
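+ # Collect candidate nodes of s1 that lie on the boundary with the other session:
+ # keys are movable node ids in s1, values are the neighboring node(s) they connect to.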
+ sdict = {}
+ if forward_direction == True:
+ for k in s1:
+ tmp_s = set(np.where(dag[k, :] == 1)[0]).difference(set(s1))
+ if len(tmp_s) > 0:
+ sdict[k] = list(tmp_s)
+ else:
+ for k in s2:
+ tmp_s = set(np.where(dag[k, :] == 1)[0]).intersection(set(s1))
+ if len(tmp_s) > 0:
+ for key in tmp_s:
+ sdict[key] = k
+ return sdict
+
+ move_success = False
+ if self._session_graph[session1][session2] == 1:
+ forward_direction = True
+ elif self._session_graph[session2][session1] == 1:
+ forward_direction = False
+ else:
+ self._logger.warning("Cannot move nodes between non-neighboring partitions")
+ return move_success
+
+ maxval = max(self._session_weights[session1], self._session_weights[session2])
+ improvement = True
+
+ marked = {}
+ while improvement == True:
+ s1 = self._session_ids[session1]
+ s2 = self._session_ids[session2]
+ sdict = session_edges(s1, s2, dag, forward_direction)
+
+ found_node = False
+ rnd_perm = np.random.permutation(list(sdict))
+ cnt = 0
+ while found_node == False and cnt < len(sdict):
+ rnd_key = rnd_perm[cnt]
+ marked[rnd_key] = True
+ found_node = True
+ if forward_direction == True:
+ for k in range(session2):
+ if len(
+ set(np.where(dag[rnd_key, :] == 1)[0]).intersection(
+ set(self._session_ids[k]))) > 0:
+ found_node = False
+ cnt += 1
+ break
+ else:
+ for k in range(session2 + 1, self._K):
+ if len(
+ set(np.where(dag[:, rnd_key] == 1)[0]).intersection(
+ set(self._session_ids[k]))) > 0:
+ found_node = False
+ cnt += 1
+ break
+ if found_node == True:
+ new_maxval = max(self._session_weights[session1] - vwgt[rnd_key],
+ self._session_weights[session2] + vwgt[rnd_key])
+ if new_maxval < maxval:
+ self._logger.info("[old maxval] %s --> %s [new maxval], id: %s",
+ maxval, new_maxval, rnd_key)
+ self._logger.debug("edges : %s", (sdict[rnd_key]))
+ if type(sdict[rnd_key]) is list:
+ rnd_val = np.random.choice(sdict[rnd_key])
+ else:
+ rnd_val = sdict[rnd_key]
+ if forward_direction == True:
+ if np.where(s2 == rnd_val)[0].size > 0:
+ s2 = np.insert(s2, np.where(s2 == rnd_val)[0], rnd_key)
+ else:
+ s2 = np.insert(s2, 0, rnd_key)
+ else:
+ if np.where(s2 == sdict[rnd_key])[0].size > 0:
+ s2 = np.insert(s2,
+ np.where(s2 == sdict[rnd_key])[0] + 1,
+ rnd_key)
+ else:
+ s2 = np.insert(s2, len(s2), rnd_key)
+ s1 = np.delete(s1, np.where(s1 == rnd_key))
+ del self._session_ids[session1]
+ self._session_ids.insert(session1, s1)
+ del self._session_ids[session2]
+ self._session_ids.insert(session2, s2)
+ self._session_weights[session1] -= vwgt[rnd_key]
+ self._session_weights[session2] += vwgt[rnd_key]
+ maxval = new_maxval
+ self.generate_session_graph()
+ move_success = True
+ else:
+ self._logger.warning("Move rejected, max value is greater")
+ improvement = False
+ else:
+ self._logger.warning(
+ "Candidate %d cannot be moved, as it violates acyclic constraint",
+ rnd_key)
+ improvement = False
+ return move_success
+ """Method to get the session with the maximum session weight, or cumulative exection time. This
+ session is then searched for its neighboring sessions. The neighbors are then ranked in increasing order
+ of their execution times, so that session moves can be performed in that order.
+ """
+
+ def get_bottleneck_info(self):
+ maxval = 0
+ ret_id = -1
+ for i in range(self._K):
+ if maxval < self._session_weights[i]:
+ maxval = self._session_weights[i]
+ ret_id = i
+ neighbor_dict = {}
+
+ for i in range(self._K):
+ if self._session_graph[ret_id][i] == 1 or self._session_graph[i][ret_id] == 1:
+ neighbor_dict[i] = self._session_weights[i]
+ sorted_neighbor_list = sorted(neighbor_dict.items(), key=lambda item: item[1])
+ self._logger.info("Bottleneck id --> %d, sorted neighbors --> %s", ret_id,
+ sorted_neighbor_list)
+ return ret_id, sorted_neighbor_list
+ """Get the cost and the partition id associated with the maximum value.
+ """
+
+ def get_maxPartitionCost(self):
+ dag = self._modelObj.get_dag()
+ maxval = 0
+ indx = -1
+ for i in range(self._K):
+ if self._session_weights[i] > maxval:
+ maxval = self._session_weights[i]
+ indx = i
+
+ def check_edges(dag, session1, session2):
+ e_cnt = 0
+ memory_overhead = 0
+ for s1 in self._session_ids[session1]:
+ for s2 in self._session_ids[session2]:
+ if dag[s1][s2] == 1:
+ e_cnt += 1
+ memory_overhead += self._modelObj.get_tensor_size(s1, s2)
+ elif dag[s2][s1] == 1:
+ self._logger.error("%d (session %d) connects to %d (session %d)",
+ s2, session2, s1, session1)
+ self._logger.error(self._session_graph)
+ sys.exit(-1)
+
+ assert (e_cnt > 0)
+ return e_cnt, memory_overhead
+
+ edge_cut = 0
+ total_memory_overhead = 0
+ for i in range(self._K - 1):
+ for j in range(i + 1, self._K):
+ if self._session_graph[i][j] == 1:
+ e_cnt, memory_overhead = check_edges(dag, i, j)
+ edge_cut += e_cnt
+ total_memory_overhead += memory_overhead
+ return indx, maxval, edge_cut, total_memory_overhead
+ """Get partition information.
+ """
+
+ def get_partitions(self):
+ return self._session_ids, self._session_weights, self._session_graph
+
+
+class GraphTopology:
+ def __init__(self, tflite_file, trace_file):
+ vertex_weights = runtime_stats.get_runtime_per_operation(trace_file)
+ self._modelObj = ModelInfo(tflite_file, vertex_weights)
+ self._dag = graph_analysis.generate_dag(tflite_file)
+ self._T = []
+ self._vwgt = np.array(vertex_weights)
+ logging.basicConfig(level=logging.INFO)
+ self._Graphlogger = logging.getLogger("Topology")
+
+ def set_dbglevel(self, dbglevel):
+ logging.basicConfig(level=dbglevel)
+ self._Graphlogger.setLevel(dbglevel)
+ """Perform Topological sort using the method outlined in https://arxiv.org/abs/1704.00705
+ """
+
+ def topological_sort(self):
+ del self._T
+ degree_matrix = np.copy(self._dag)
+ n = self._dag.shape[0]
+ S = []
+ T = LifoQueue(maxsize=n)
+ marked = {}
+
+ while T.qsize() < n:
+ indegree = np.sum(degree_matrix, axis=0)
+ candidates, = np.where(indegree == 0)
+ for i in candidates:
+ if i not in marked:
+ S.append(i)
+ np.random.seed()
+ random_pos = int(np.random.rand() * len(S))
+ random_node = S[random_pos]
+ marked[random_node] = True
+ T.put(random_node)
+ neighbors, = np.where(self._dag[random_node, :] == 1)
+ for i in neighbors:
+ degree_matrix[random_node][i] = 0
+ del S
+ S = []
+
+ self._T = list(T.queue)
+ """Create a partition instance and perform an initial split over the cumulative sum weights
+ """
+
+ def partition_graph(self, K):
+ self._partition = GraphPartition(K)
+ self._partition.initial_partition(self._modelObj, self._T)
+ """Move nodes between sessions id1 and id2
+ """
+
+ def partition_move(self, id1, id2):
+ return self._partition.move_nodes(id1, id2)
+ """Summarize partition information
+ """
+
+ def partition_summary(self):
+ self._partition.summarize(self._T)
+ """Optimize for minmax partition. At each iteration, find the bottlenecked partition, and shuffle nodes out of it
+ to its neighbor with the smallest weight. If the neighbor session cannot accomodate any more nodes (because the minmax criterion is violated),
+ then select the next neighbor with the smallest weight. Repeat iterations until no further improvement is possible.
+ """
+
+ def partition_minmax(self, oneshot=False):
+ improvement = True
+ while improvement == True:
+ improvement = False
+ bottleneck_id, neighbor_list = self._partition.get_bottleneck_info()
+ for neighbor, wgt in neighbor_list:
+ self._Graphlogger.debug("====Moving from session %d to session %d",
+ bottleneck_id, neighbor)
+ ret_success = self.partition_move(bottleneck_id, neighbor)
+ if ret_success == True:
+ improvement = True
+ self._Graphlogger.debug(
+ "====Successful move from session %d to session %d",
+ bottleneck_id, neighbor)
+ break
+ self._Graphlogger.debug("====Failed move from session %d to session %d",
+ bottleneck_id, neighbor)
+ if oneshot == True:
+ self.partition_summary()
+
+ return self._partition.get_maxPartitionCost()
+ """Perform MinMax partitioning over multiple runs, and pick the best solution.
+ """
+
+ def partition_minmax_multiple(self, K=3, nruns=100):
+ minval = np.inf
+ session_ids = []
+ session_weights = np.zeros(K, dtype=int)
+ edge_cut_best = 0
+ memory_overhead_best = 0
+ for run in range(nruns):
+ self._Graphlogger.debug("****Starting run %d", run)
+ self.topological_sort()
+ self.partition_graph(K)
+ indx, maxval, edge_cut, memory_overhead = self.partition_minmax()
+ if maxval < minval:
+ minval = maxval
+ edge_cut_best = edge_cut
+ memory_overhead_best = memory_overhead
+ session_ids, session_weights, session_graph = self._partition.get_partitions(
+ )
+ self._Graphlogger.debug("****Finished run %d", run)
+
+ self._Graphlogger.info("Done.. printing results")
+ self._Graphlogger.info("Session ids: ")
+ for i in range(K):
+ self._Graphlogger.info("Partition %d : %s, sum weight = %s", i,
+ session_ids[i].tolist(), session_weights[i])
+ self._Graphlogger.info(
+ "Session Graph:\n%s",
+ np.array2string(
+ session_graph, formatter={'int': lambda x: '{:>3}'.format(x)}))
+ self._Graphlogger.info("Edge cut: %d", edge_cut_best)
+ self._Graphlogger.info("Memory overhead (bytes): %d", memory_overhead_best)
+ output_data = {}
+ partition_map = np.zeros(self._dag.shape[0], dtype=int)
+ with open("".join([self._modelObj.get_model_path(), "/parition_map.json"]),
+ "w") as ofile:
+ for i in range(K):
+ for op_idx in session_ids[i]:
+ partition_map[op_idx] = i
+ output_data['partition_map'] = partition_map.tolist()
+ output_data['num_partitions'] = K
+ json.dump(output_data, ofile)
--- /dev/null
+# Heuristic Graph Partitioning
+
+This folder contains the scripts needed to perform heuristic-based graph partitioning of machine learning models.
+
+The main contents of this folder are as follows:
+
+- [Python Files](#python-scripts)
+- [How To Partition TFLite Model?](#how-to-partition-tflite-model)
+- [Example Script](#example-script-to-generate-partition-map)
+
+
+## Python Scripts
+The Python scripts (**python3**) require the TensorFlow 2.x package in order to retrieve TFLite model operations. Additionally, please ensure that the Python `numpy` package is installed beforehand. The scripts also import the `queue`, `json` and `argparse` modules, all of which should be available by default. If not, please install them either with `pip install <package>` or `sudo apt install python-<package>`.
+
+`Graph.py` is the main script: it processes the model graph topology and implements the partitioning algorithm. It defines two classes, `GraphTopology` and `GraphPartition`; a `GraphTopology` object holds a `GraphPartition` object as a member.
+
+`graph_analysis.py` is a helper module that translates TFLite models into graph data structures; it is imported inside
+`Graph.py`.
+
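+For reference, below is a minimal sketch of using `graph_analysis.py` on its own (the `inceptionV3.tflite` filename is only an example; any TFLite model path works):
+
+```
+import graph_analysis
+
+# Directed acyclic graph of the model as a binary numpy adjacency matrix (ops x ops)
+dag = graph_analysis.generate_dag('inceptionV3.tflite')
+print(dag.shape)
+```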
+
+## How To Partition TFLite Model?
+To partition a TFLite model, simply `import Graph` at the outset. There are two ways to run the partitioning algorithm: a quick run that goes straight to the final result, and a detailed run that exposes the intermediate steps.
+
+### Quick Run For Final Result
+To get the partitioning result quickly, follow the steps below:
+
+1. Create a `GraphTopology` object as shown below:
+```
+In [70]: g = Graph.GraphTopology('inceptionV3.tflite', 'inceptionV3.chrome.json')
+```
+**Note**: Here, the argument `inceptionV3.chrome.json` is a single-execution trace of operation execution times, and is obtained using the Chrome Trace profiler.
+
+2. Run the **MinMax** partitioning algorithm over the topology. Specify the number of partitions (K) and the number of topological orderings (nruns) to evaluate before settling on the best result.
+```
+In [71]: g.partition_minmax_multiple(K=4, nruns=10)
+
+INFO:Topology:Session ids:
+INFO:Topology:Partition 0 : [0, 1, 2, 3, 4, 5, 6, 13, 7, 10, 14, 11, 12, 8, 9, 15, 22, 16, 23, 19, 17, 20, 21], sum weight = 292393
+INFO:Topology:Partition 1 : [18, 24, 26, 27, 28, 25, 29, 30, 31, 32, 33, 34, 38, 35, 36, 37, 39, 49, 44, 41, 45, 42, 50, 46, 43, 40, 47, 48, 51, 53, 56, 52], sum weight = 293959
+INFO:Topology:Partition 2 : [61, 57, 58, 54, 55, 62, 59, 60, 63, 73, 74, 65, 64, 68, 66, 69, 67, 70, 71, 72, 75, 76, 80, 77, 78, 79, 81, 82], sum weight = 290835
+INFO:Topology:Partition 3 : [83, 84, 85, 86, 87, 90, 94, 91, 88, 89, 92, 93, 95, 96, 101, 97, 106, 98, 102, 103, 104, 99, 105, 107, 100, 108, 109, 110, 111, 114, 119, 120, 112, 115, 116, 117, 113, 118, 121, 122, 123, 124, 125], sum weight = 293819
+INFO:Topology:Session Graph:
+[[ 0 1 0 0]
+ [ 0 0 1 0]
+ [ 0 0 0 1]
+ [ 0 0 0 0]]
+INFO:Topology:Edge cut: 12
+INFO:Topology:Memory overhead (bytes): 4366144
+
+In [72]:
+```
+
+### Detailed View
+For a detailed breakdown of the runtime steps, execute the function calls shown below:
+
+1. Create a `GraphTopology` object:
+```
+In [70]: g = Graph.GraphTopology('inceptionV3.tflite', 'inceptionV3.chrome.json')
+```
+
+2. Perform a topological sort
+```
+In [73]: g.topological_sort()
+```
+
+
+3. Partition the graph into K sub-graphs, using the topological order obtained above
+```
+In [74]: g.partition_graph(K=4)
+```
+
+4. View the execution time of each partition
+```
+In [75]: g.partition_summary()
+INFO:Minmax:Session Graph:
+[[ 0 1 0 0]
+ [ 0 0 1 0]
+ [ 0 0 0 1]
+ [ 0 0 0 0]]
+INFO:Minmax:Partition 0 : [0, 1, 2, 3, 4, 5, 6, 13, 8, 7, 14, 9, 10, 11, 12, 15, 22, 23, 17, 16, 18, 19], sum weight = 276635
+INFO:Minmax:Partition 1 : [20, 21, 24, 26, 28, 27, 29, 25, 30, 31, 32, 33, 38, 35, 36, 34, 37, 39, 40, 41, 44, 45, 46, 42, 43, 49, 50, 47, 48, 51, 52, 61], sum weight = 299334
+INFO:Minmax:Partition 2 : [56, 53, 54, 57, 58, 55, 59, 60, 62, 63, 73, 65, 66, 67, 68, 69, 74, 70, 64, 71, 72, 75, 80, 81, 77, 76, 78, 82, 85], sum weight = 291593
+INFO:Minmax:Partition 3 : [83, 86, 84, 79, 87, 94, 90, 91, 88, 92, 89, 93, 95, 106, 107, 96, 97, 101, 102, 98, 104, 103, 99, 100, 105, 108, 114, 109, 119, 120, 110, 112, 111, 113, 115, 117, 116, 118, 121, 122, 123, 124, 125], sum weight = 303444
+```
+
+5. Run a *OneShot* version of the partitioning algorithm
+```
+In [90]: indx, minmax, edge_cnt, memory_overhead = g.partition_minmax(oneshot=True)
+INFO:Minmax:Bottleneck id --> 3, sorted neighbors --> [(2, 291593)]
+DEBUG:Topology:====Moving from session 3 to session 2
+INFO:Minmax:[old maxval] 303444 --> 300754 [new maxval], id: 86
+WARNING:Minmax:Candidate 87 cannot be moved, as it violates acyclic constraint
+WARNING:Minmax:Move rejected, max value is greater
+DEBUG:Topology:====Successful move from session 3 to session 2
+INFO:Minmax:Bottleneck id --> 2, sorted neighbors --> [(3, 294283), (1, 299334)]
+DEBUG:Topology:====Moving from session 2 to session 3
+WARNING:Minmax:Move rejected, max value is greater
+DEBUG:Topology:====Failed move from session 2 to session 3
+DEBUG:Topology:====Moving from session 2 to session 1
+WARNING:Minmax:Move rejected, max value is greater
+DEBUG:Topology:====Failed move from session 2 to session 1
+INFO:Minmax:Session Graph:
+[[ 0 1 0 0]
+ [ 0 0 1 0]
+ [ 0 0 0 1]
+ [ 0 0 0 0]]
+INFO:Minmax:Partition 0 : [0, 1, 2, 3, 4, 5, 6, 13, 8, 7, 14, 9, 10, 11, 12, 15, 22, 23, 17, 16, 18, 19], sum weight = 276635
+INFO:Minmax:Partition 1 : [20, 21, 24, 26, 28, 27, 29, 25, 30, 31, 32, 33, 38, 35, 36, 34, 37, 39, 40, 41, 44, 45, 46, 42, 43, 49, 50, 47, 48, 51, 52, 61], sum weight = 299334
+INFO:Minmax:Partition 2 : [56, 53, 54, 57, 58, 55, 59, 60, 62, 63, 73, 65, 66, 67, 68, 69, 74, 70, 64, 71, 72, 75, 80, 81, 77, 76, 78, 82, 85, 86], sum weight = 300754
+INFO:Minmax:Partition 3 : [83, 84, 79, 87, 94, 90, 91, 88, 92, 89, 93, 95, 106, 107, 96, 97, 101, 102, 98, 104, 103, 99, 100, 105, 108, 114, 109, 119, 120, 110, 112, 111, 113, 115, 117, 116, 118, 121, 122, 123, 124, 125], sum weight = 294283
+```
+
+**Note**: Set the debug level in the script as needed, for example, `g.set_dbglevel(logging.DEBUG)`.
+
+## Example Script To Generate Partition Map
+An example script `test_partition.py` is included in the folder. Please run `python3 test_partition.py --help` for details. The script takes the TFLite model file and the trace JSON as arguments, and creates a `partition_map.json` in the same location as the TFLite file. Sample output from running `test_partition.py` is shown below:
+
+```
+$ python3 test_partition.py /tmp/nnpackage/inception_v3/inception_v3.tflite /tmp/inceptionV3.chrome.json --num_parts=4
+
+...
+...
+INFO:Topology:Partition 0 : [0, 1, 2, 3, 4, 5, 6, 8, 13, 7, 9, 14, 10, 11, 12, 15, 19, 17, 16, 20, 21, 22, 23], sum weight = 292393
+INFO:Topology:Partition 1 : [18, 24, 28, 31, 32, 29, 30, 25, 26, 27, 33, 35, 34, 38, 36, 37, 39, 49, 40, 44, 45, 41, 50, 42, 43, 46, 47, 48, 51, 52, 56, 57], sum weight = 296611
+INFO:Topology:Partition 2 : [53, 61, 54, 58, 59, 60, 62, 55, 63, 68, 65, 73, 66, 64, 69, 67, 74, 70, 71, 72, 75, 80, 76, 81, 85, 82, 83, 77, 86], sum weight = 286608
+INFO:Topology:Partition 3 : [78, 79, 84, 87, 94, 90, 91, 88, 89, 92, 93, 95, 96, 106, 101, 102, 104, 107, 103, 97, 99, 105, 98, 100, 108, 114, 119, 120, 115, 117, 110, 116, 118, 112, 109, 111, 113, 121, 122, 123, 124, 125], sum weight = 295394
+INFO:Topology:Session Graph:
+[[ 0 1 0 0]
+ [ 0 0 1 0]
+ [ 0 0 0 1]
+ [ 0 0 0 0]]
+INFO:Topology:Edge cut: 12
+INFO:Topology:Memory overhead (bytes): 4403136
+
+$ cat /tmp/nnpackage/inception_v3/partition_map.json
+{"partition_map": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 3, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], "num_partitions": 4}
+$
+```
--- /dev/null
+#! /usr/bin/python
+
+import sys
+import numpy as np
+import tensorflow as tf
+"""Return list of operations from TFLite model
+"""
+
+
+def get_model_ops(tflite_file):
+ intr = tf.lite.Interpreter(tflite_file)
+ intr.allocate_tensors()
+ ops = intr._get_ops_details()
+ return ops
+
+
+"""Return list of tensors from TFLite model
+"""
+
+
+def get_model_tensors(tflite_file):
+ intr = tf.lite.Interpreter(tflite_file)
+ intr.allocate_tensors()
+ tensors = intr.get_tensor_details()
+ return tensors
+
+
+"""Generate binary adjacency matrix from a tflite model. The adjacency matrix is symmetric and
+undirected.
+"""
+
+
+def generate_adj_matrix(tflite_file):
+ intr = tf.lite.Interpreter(tflite_file)
+ intr.allocate_tensors()
+ ops = intr._get_ops_details()
+ adj_mat = np.zeros((len(ops), len(ops)), dtype=int)
+ for i in range(len(ops) - 1):
+ for j in range(i + 1, len(ops)):
+ if i != j:
+ if len(set(ops[i]['outputs']).intersection(set(ops[j]['inputs']))) > 0:
+ adj_mat[i][j] = 1
+ adj_mat[j][i] = 1
+ if len(set(ops[i]['inputs']).intersection(set(ops[j]['outputs']))) > 0:
+ adj_mat[i][j] = 1
+ adj_mat[j][i] = 1
+ return adj_mat
+
+
+"""Generate directed acyclic graph (DAG) from a tflite model.
+"""
+
+
+def generate_dag(tflite_file):
+ intr = tf.lite.Interpreter(tflite_file)
+ intr.allocate_tensors()
+ ops = intr._get_ops_details()
+ adj_mat = np.zeros((len(ops), len(ops)), dtype=int)
+ for i in range(len(ops) - 1):
+ for j in range(i + 1, len(ops)):
+ if i != j:
+ if len(set(ops[i]['outputs']).intersection(set(ops[j]['inputs']))) > 0:
+ adj_mat[i][j] = 1
+ if len(set(ops[i]['inputs']).intersection(set(ops[j]['outputs']))) > 0:
+ adj_mat[j][i] = 1
+ return adj_mat
+
+
+"""Generate Compressed Sparse Row format (CSR) of a adjacency matrix. Details on CSR are given at
+https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format).
+"""
+
+
+def get_csr(adj_matrix):
+ row_ptr = []
+ col_ind = []
+ assert (adj_matrix.shape[0] == adj_matrix.shape[1])
+ n = adj_matrix.shape[0]
+ cnt = 0
+ for i in range(n):
+ first = True
+ for j in range(n):
+ if adj_matrix[i][j] == 1:
+ col_ind.append(j)
+ if first == True:
+ first = False
+ row_ptr.append(cnt)
+ cnt += 1
+ row_ptr.append(cnt)
+ return row_ptr, col_ind
+
+
+"""Perform basic spectral clustering given a tflite model. The graph in this case is symmetric, undirected with
+unit weight per edge. Therefore, the spectral clustering is performed on a binary (0-1) adjacency matrix derived
+from the tflite model.
+"""
+
+
+def spectral_cluster(tflite_file):
+ adj_matrix = generate_adj_matrix(tflite_file)
+ L = np.diag(np.sum(adj_matrix, axis=0)) - adj_matrix
+ e_val, e_vec = np.linalg.eig(L)
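+ # Sort eigenvectors by eigenvalue and return the Fiedler vector (the eigenvector of the
+ # second-smallest eigenvalue of the Laplacian), which yields a two-way spectral split.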
+ vecs = e_vec[:, np.argsort(e_val)]
+ return vecs.T[1]
--- /dev/null
+#! /usr/bin/python
+import json
+from queue import LifoQueue
+
+
+def get_runtime_per_operation(trace_file):
+ with open(trace_file) as ifile:
+ data = json.load(ifile)
+ traceEvents = data['traceEvents']
+ time_val = {}
+ stack = LifoQueue(maxsize=1000)
+ for t in traceEvents:
+ if t == {}:
+ continue
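+ # Skip whole-graph, subgraph and permute events; remaining "B" (begin) / "E" (end) event
+ # pairs are matched via the stack to compute a per-operation execution time.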
+ if (t["name"].lower() != "graph" and "permute" not in t["name"].lower()) and \
+ ("subg" not in t["name"].lower() and "permute" not in t["name"].lower()):
+ if t["ph"] == "B":
+ stack.put((t["name"], int(t["ts"])))
+ elif t["ph"] == "E":
+ opname, st_time = stack.get()
+ assert (opname == t["name"])
+ if "$" in t["name"]:
+ time_val[int(
+ t["name"].split(" ")[0].lstrip('$'))] = int(t["ts"]) - st_time
+ else:
+ time_val[int(
+ t["name"].split(" ")[0].lstrip('@'))] = int(t["ts"]) - st_time
+
+ time_idx = [y for x, y in (sorted(time_val.items(), key=lambda item: item[0]))]
+ return time_idx
--- /dev/null
+#! /usr/bin/python
+
+import argparse
+import os
+import Graph
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ "test_partition.py", description="Example code to partition models")
+ parser.add_argument("modelfile", type=str, help="TFLite file with path")
+ parser.add_argument("tracefile", type=str, help="Chrome trace file with path")
+ parser.add_argument("--num_parts", type=int, default=2, help="Number of partitions")
+ parser.add_argument(
+ "--num_runs", type=int, default=10, help="Number of runs (topological orderings)")
+
+ # Parse arguments
+ args = parser.parse_args()
+
+ # Partition
+ g = Graph.GraphTopology(args.modelfile, args.tracefile)
+ g.partition_minmax_multiple(K=args.num_parts, nruns=args.num_runs)
}
if [ $# -eq 0 ]; then
- echo "For help, type $progname -h"
+ >&2 echo "For help, type $progname -h"
exit 1
fi
shift $((OPTIND-1))
if [ $# -ne 1 ]; then
- echo "error: wrong argument (no argument or too many arguments)."
- echo "For help, type $progname -h"
+ >&2 echo "error: wrong argument (no argument or too many arguments)."
+ >&2 echo "For help, type $progname -h"
exit 1
fi
modelfile=$(basename "$1")
if [[ "$modelfile" != *.* ]]; then
- echo "error: modelfile does not have extension."
- echo "Please provide extension so that $progname can identify what type of model you use."
+ >&2 echo "error: modelfile does not have extension."
+ >&2 echo "Please provide extension so that $progname can identify what type of model you use."
exit 1
fi
if [ ! -e $1 ]; then
- echo "error: "$1" does not exist."
+ >&2 echo "error: "$1" does not exist."
exit 1
fi
fi
extension=${modelfile##*.}
-echo "Generating nnpackage "$name" in "$outdir""
+echo "$progname: Generating nnpackage "$name" in "$outdir""
mkdir -p "$outdir"/"$name"/metadata
if [ -s "$config_src" ]; then
+++ /dev/null
-# tflite2circle
-
-`tflite2circle` is a tool to convert tflite into circle.
-
-## Usage
-
-```
-Usage: tflite2circle.sh [options] tflite
-Convert tflite to circle
-
-Returns
- 0 success
- non-zero failure
-
-Options:
- -h show this help
- -o set output directory (default=.)
-
-Environment variables:
- flatc path to flatc
- (default=./build/externals/FLATBUFFERS/build/flatc)
- tflite_schema path to schema.fbs
- (default=./externals/TENSORFLOW-1.12/tensorflow/contrib/lite/schema/schema.fbs)
-
-Examples:
- tflite2circle.sh Add_000.tflite => convert Add_000.tflite into Add_000.circle
- tflite2circle.sh -o my/circles Add_000 => convert Add_000.tflite into my/circles/Add_000.circle
-```
+++ /dev/null
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-'use strict'
-
-// read json and parse
-const fs = require('fs')
-let inputfile = "./03_2k.json"
-if (process.argv.length == 3)
- inputfile = process.argv[2]
-let raw = fs.readFileSync(inputfile)
-let model = JSON.parse(raw)
-
-// 0. prepare shortcut variables with object destructuring
-const { operators, tensors } = model.subgraphs[0]
-
-//--------------------------------------------------------------------------
-// 0. construct infra
-
-// List : opcode index (number) => op name (string)
-let opcodeIdxToOpName = []
-for (const opcode of model.operator_codes) {
- opcodeIdxToOpName.push(opcode.builtin_code)
-}
-
-// List: tensor index (number) => producing operator's index (number)
-// assume there is only one op that produces given output tensor.
-let defOp = []
-for (let i = 0; i < operators.length; ++i) {
- let op = operators[i]
- if (op.outputs.length !== 1) {
- console.debug("Assumption failed. Multiple output operator exists.")
- process.exit(-1);
- }
- defOp[op.outputs[0]] = i
-}
-
-// List: tensor index (number) => consuming operator indices (list of number)
-// Note that there may be multiple consumer ops for a given tensor index
-let useOps = []
-for (let i = 0; i < operators.length; ++i) {
- let op = operators[i]
- for (let inTensorIdx of op.inputs) {
- if (useOps[inTensorIdx])
- useOps[inTensorIdx].push(i)
- else
- useOps[inTensorIdx] = [ i ]
- }
-}
-
-// return operator that defines the given tensor index
-function getDefOp(iTensor) {
- return defOp[iTensor] === undefined ? undefined : operators[defOp[iTensor]]
-}
-
-function getUseOps(iTensor) {
- if (useOps[iTensor] === undefined)
- return undefined
- let ret = []
- for (let i of useOps[iTensor])
- ret.push(operators[i])
- return ret
-}
-
-function opeq(op, str) {
- return op === undefined ? undefined : opcodeIdxToOpName[op.opcode_index] === str
-}
-
-function hasUndefined() {
- for (let arg of arguments)
- if (arg === undefined)
- return true
- return false
-}
-
-//--------------------------------------------------------------------------
-// find SquaredDifference as starting point
-let squaredDifferenceIdxList = []
-for (let i = 0; i < operators.length; ++i) {
- if (opeq(operators[i], "SQUARED_DIFFERENCE"))
- squaredDifferenceIdxList.push(i)
-}
-
-let instanceNormList = [ ]
-for (let idx of squaredDifferenceIdxList) {
- const sqd1 = operators[idx]
- const findMean0AndInstanceNormInputTensor = function(sqd1) {
- let mean0, iInstanceNormInputTensor
- for (let i = 0; i < sqd1.inputs.length; ++i) {
- let op = getDefOp(sqd1.inputs[i])
- if (opeq(op, "MEAN")) {
- mean0 = op
- // let's check one of inputs are instance_norm
- // the other input is axis of mean operator.
- for (let j = 0; j < mean0.inputs.length; ++j) {
- // 1 - i means the other input of squared_difference.
- if (mean0.inputs[j] === sqd1.inputs[1 - i]) {
- iInstanceNormInputTensor = mean0.inputs[j]
- }
- if (!hasUndefined(iInstanceNormInputTensor)) break // found instance_norm
- }
- }
- if (!hasUndefined(mean0, iInstanceNormInputTensor)) break
- }
- return [mean0, iInstanceNormInputTensor]
- }
- const [mean0, iInstanceNormInputTensor] = findMean0AndInstanceNormInputTensor(sqd1)
- if (hasUndefined(mean0, iInstanceNormInputTensor)) continue
-
- const findConsumer = function(op, expectedOp) {
- let ops = getUseOps(op.outputs[0])
- if (ops === undefined || ops.length !== 1 || !opeq(ops[0], expectedOp))
- return undefined
- return ops[0]
- }
- const mean2 = findConsumer(sqd1, "MEAN")
- if (hasUndefined(mean2)) continue
-
- const add3 = findConsumer(mean2, "ADD")
- if (hasUndefined(add3)) continue
-
- const isScalar = function(tsr) { return tsr.shape.length === 0 }
- const is1D = function(tsr) { return tsr.shape.length === 1 }
- const isFloat32 = function(tsr) { return tsr.type === "FLOAT32" }
- const asFloat32 = function(arr) { return new Float32Array(new Uint8Array(arr).buffer)[0]; }
- const getFloatScalarValueFromInputsOf = function(op) {
- for (let i of op.inputs) {
- if (isScalar(tensors[i]) && isFloat32(tensors[i])) {
- let buf = model.buffers[tensors[i].buffer]
- if (buf.data && buf.data.length === 4)
- return asFloat32(buf.data)
- }
- }
- return undefined
- }
- const epsilon = getFloatScalarValueFromInputsOf(add3)
- if (hasUndefined(epsilon)) continue
-
- const rsqrt4 = findConsumer(add3, "RSQRT")
- if (hasUndefined(rsqrt4)) continue
-
- const mul5 = findConsumer(rsqrt4, "MUL")
- if (hasUndefined(mul5)) continue
-
- const getFloat1DTensorIdxFromInputsOf = function(op) {
- for (let i of op.inputs) {
- if (is1D(tensors[i]) && isFloat32(tensors[i]))
- return i
- }
- return undefined
- }
- const iGamma = getFloat1DTensorIdxFromInputsOf(mul5)
- if (hasUndefined(iGamma)) continue
-
- let mul6, mul7
- for (let i of useOps[mul5.outputs[0]]) {
- const op = operators[i]
- if (opcodeIdxToOpName[op.opcode_index] !== "MUL")
- break;
- const otherInput = op.inputs[0] === mul5.outputs[0] ? op.inputs[1] : op.inputs[0]
- if (otherInput === iInstanceNormInputTensor)
- mul6 = op
- else if (otherInput === mean0.outputs[0])
- mul7 = op
- }
- if (hasUndefined(mul6, mul7)) continue
-
- const sub8 = findConsumer(mul7, "SUB")
- if (hasUndefined(sub8)) continue
-
- const iBeta = getFloat1DTensorIdxFromInputsOf(sub8)
- if (hasUndefined(iBeta)) continue
-
- const add9 = findConsumer(sub8, "ADD")
- if (hasUndefined(add9)) continue
-
- const add9_2 = findConsumer(mul6, "ADD")
- if (hasUndefined(add9_2)) continue
-
- if (add9 !== add9_2)
- continue
-
- const getActivation = function(op) {
- return op.builtin_options.fused_activation_function
- }
- const activation = getActivation(add9)
- if (hasUndefined(activation)) continue
-
- //--------------------------------------------------------------------------
- // convert to instance norm
- let instanceNormOpcodeIdx = model.operator_codes.findIndex(o => { return o.builtin_code === "INSTANCE_NORM" })
- if (instanceNormOpcodeIdx === -1) {
- model.operator_codes.push( { "builtin_code": "INSTANCE_NORM", "version": 1 } )
- instanceNormOpcodeIdx = model.operator_codes.length - 1;
- }
- // construct instance norm operator
- let instanceNorm = {
- "opcode_index": instanceNormOpcodeIdx,
- "inputs": [ iInstanceNormInputTensor, iGamma, iBeta ],
- "outputs": [ add9.outputs[0] ],
- "builtin_options": { "epsilon": epsilon, "fused_activation_function": activation },
- "builtin_options_type": "InstanceNormOptions",
- "custom_options_format": "FLEXBUFFERS",
- "mutating_variable_inputs": [],
- }
- // record the fused operator with the index of its SQUARED_DIFFERENCE so that
- // only successfully matched subgraphs are spliced out below
- instanceNormList.push({ "sqdIdx": idx, "op": instanceNorm })
-} // end of sqd1
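-
-// Replace each matched 10-operator subgraph with its fused operator. The
-// subgraph is assumed to start at the MEAN immediately preceding the
-// SQUARED_DIFFERENCE (hence splicing from sqdIdx - 1); every fusion removes
-// 10 operators and inserts 1, so later indices are shifted by `adjust`.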
-let adjust = 0
-for (let item of instanceNormList) {
- operators.splice(item.sqdIdx - 1 + adjust, 10, item.op)
- adjust -= 9
-}
-let raw_fused = JSON.stringify(model)
-fs.writeFileSync(inputfile+".fused", raw_fused);
+++ /dev/null
-#!/bin/bash
-
-set -u
-
-progname=$(basename "${BASH_SOURCE[0]}")
-script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
-nnfw_root="$( cd "${script_dir%*/*/*/*}" && pwd )"
-outdir="."
-flatc=${flatc:-"$nnfw_root/build/externals/FLATBUFFERS/build/flatc"}
-tflite_schema=${tflite_schema:-"$nnfw_root/externals/TENSORFLOW-1.13.1/tensorflow/lite/schema/schema.fbs"}
-circle_schema=${circle_schema:-"$nnfw_root/nnpackage/schema/circle_schema.fbs"}
-
-if ! [ -x "$flatc" ]; then
- echo "Please make sure `flatc` is in path."
- exit 2
-fi
-
-if ! { [ -e "$tflite_schema" ] && [ -e "$circle_schema" ]; }; then
- echo "Please make sure that the `*.fbs` paths are set properly."
- exit 3
-fi
-
-usage() {
- echo "Usage: $progname [options] tflite"
- echo "Convert tflite to circle"
- echo ""
- echo "Returns"
- echo " 0 success"
- echo " non-zero failure"
- echo ""
- echo "Options:"
- echo " -h show this help"
- echo " -o set output directory (default=$outdir)"
- echo ""
- echo "Environment variables:"
- echo " flatc path to flatc"
- echo " (default=./build/externals/FLATBUFFERS/build/flatc)"
- echo " tflite_schema path to tflite schema (i.e. schema.fbs)"
- echo " (default=./externals/TENSORFLOW-1.12/tensorflow/contrib/lite/schema/schema.fbs)"
- echo " circle_schema path to circle schema"
- echo " (default=./nnpackage/schema/circle_schema.fbs)"
- echo ""
- echo "Examples:"
- echo " $progname Add_000.tflite => convert Add_000.tflite into Add_000.circle"
- echo " $progname -o my/circles Add_000 => convert Add_000.tflite into my/circles/Add_000.circle"
- exit 1
-}
-
-if [ $# -eq 0 ]; then
- echo "For help, type $progname -h"
- exit 1
-fi
-
-while getopts "ho:" OPTION; do
-case "${OPTION}" in
- h) usage;;
- o) outdir=$OPTARG;;
- ?) exit 1;;
-esac
-done
-
-shift $((OPTIND-1))
-
-if [ $# -ne 1 ]; then
- echo "error: wrong argument (no argument or too many arguments)."
- echo "For help, type $progname -h"
- exit 1
-fi
-
-tflite_base=$(basename "$1")
-name=${tflite_base%.*}
-
-# convert
-
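-# Three steps: (1) dump the tflite flatbuffer to JSON with flatc --strict-json -t,
-# (2) rewrite that JSON into circle form with tflitejson2circlejson.py, and
-# (3) build the binary circle flatbuffer from it against the circle schema.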
-mkdir -p "${outdir}"
-"${flatc}" -o "${outdir}" --strict-json -t "${tflite_schema}" -- "$1"
-"${script_dir}/tflitejson2circlejson.py" "${outdir}/${name}.json" > "${outdir}/${name}.circle"
-"${flatc}" -o "${outdir}" -b "${circle_schema}" "${outdir}/${name}.circle"
-rm -f "${outdir}/${name}.json"
+++ /dev/null
-#!/usr/bin/python3
-
-# Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
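-# Convert a TFLite model dumped to JSON (flatc --strict-json -t) into circle-style
-# JSON: the top-level schema `version` field is rewritten and the result is
-# printed to stdout.
-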
-import json
-import os
-import sys
-from collections import OrderedDict
-
-
-def usage():
- script = os.path.basename(__file__)
- print("Usage: {} path_to_tflite_in_json".format(script))
- sys.exit(-1)
-
-
-if __name__ == '__main__':
- if len(sys.argv) != 2:
- usage()
-
- json_path = sys.argv[1]
- with open(json_path, "r") as f:
- try:
- json_dict = json.load(f, object_pairs_hook=OrderedDict)
- json_dict["version"] = 0
- print(json.dumps(json_dict, indent=2))
- except KeyError:
- print("subgraphs attribute does not exist.")
- sys.exit(-2)
current_version=${version_line#"Version:"}
if [ $nightly -eq 0 ]; then
- # Get head commit's date
- pushd $nnfw_root > /dev/null
- date=$(git log -1 --format=%ad --date=format:%y%m%d)
- echo $current_version-nightly-$date
- popd > /dev/null
+ echo $current_version~$(date "+%y%m%d%H")
else
echo $current_version
fi
-flatbuffers>=1.12
+flatbuffers==1.12
numpy