From: 남궁석/On-Device Lab(SR)/Engineer/삼성전자 Date: Wed, 18 Sep 2019 06:17:24 +0000 (+0900) Subject: [loco] Introduce TensorReduce Operation (#7497) X-Git-Tag: submit/tizen/20191205.083104~1194 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0bfe82f519525f394c296cad9253fdb5dfa79d2d;p=platform%2Fcore%2Fml%2Fnnfw.git [loco] Introduce TensorReduce Operation (#7497) * [loco] Introduce TensorReduce Operation This commit will introduce `TensorReduce` operation in loco Signed-off-by: Seok NamKoong * fix wrong shapeInference * revise comments --- diff --git a/compiler/loco/include/loco/IR/CanonicalNodes.lst b/compiler/loco/include/loco/IR/CanonicalNodes.lst index 32d0629..71fc8d0 100644 --- a/compiler/loco/include/loco/IR/CanonicalNodes.lst +++ b/compiler/loco/include/loco/IR/CanonicalNodes.lst @@ -36,5 +36,6 @@ CANONICAL_NODE(Tanh, Tanh) CANONICAL_NODE(TensorConcat, TensorConcat) CANONICAL_NODE(TensorBiasAdd, BiasAdd) CANONICAL_NODE(TensorBroadcast, TensorBroadcast) +CANONICAL_NODE(TensorReduce, TensorReduce) CANONICAL_NODE(TensorSoftmax, Softmax) CANONICAL_NODE(TransposedConv2D, TransposedConv2D) diff --git a/compiler/loco/include/loco/IR/Nodes.h b/compiler/loco/include/loco/IR/Nodes.h index 316b263..6ffad19 100644 --- a/compiler/loco/include/loco/IR/Nodes.h +++ b/compiler/loco/include/loco/IR/Nodes.h @@ -27,6 +27,7 @@ #include "loco/IR/Stride.h" #include "loco/IR/Padding2D.h" #include "loco/IR/TensorAxis.h" +#include "loco/IR/TensorAxisSet.h" #include "loco/IR/FeatureCodec.h" #include "loco/IR/FilterCodec.h" #include "loco/IR/DepthwiseFilterCodec.h" @@ -566,6 +567,39 @@ private: }; /** + * @brief Reduce type functions + */ +enum class ReduceFunc +{ + Mean, // ReduceMean + // TODO Support other reduce operations +}; + +/** + * @brief Computes ReduceFunc operations for Tensor domain + * @note All the reduce functions always keep dimensions + */ +class TensorReduce final + : public CanonicalNodeDef::Mixin> +{ +public: + Node *input(void) const { return at(0)->node(); } + void input(Node *node) { at(0)->node(node); } + +public: + const TensorAxisSet *axes(void) const { return &_axes; } + TensorAxisSet *axes(void) { return &_axes; } + +public: + ReduceFunc func(void) const { return _func; } + void func(ReduceFunc func) { _func = func; } + +private: + TensorAxisSet _axes; + ReduceFunc _func; +}; + +/** * @brief 2D Transposed Convolution * * @note TransposedConv2D have a few important conventions that IR users should diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 5934007..f0e4ec6 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -547,6 +547,19 @@ public: return loco::NodeShape{tensor_shape}; } + // CASE: TensorReduce + loco::NodeShape visit(const loco::TensorReduce *node) final + { + auto tensor_shape = node_shape(node->input()).as(); + auto const tensor_rank = tensor_shape.rank(); + + for (uint32_t d = 0; d < tensor_rank; ++d) + if (node->axes()->defined(d)) + tensor_shape.dim(d) = 1; + + return loco::NodeShape{tensor_shape}; + } + // CASE: TensorSoftmax loco::NodeShape visit(const loco::TensorSoftmax *node) final { return node_shape(node->input()); } }; diff --git a/compiler/loco/src/Service/TypeInference.cpp b/compiler/loco/src/Service/TypeInference.cpp index 92cdbd5..53915a9 100644 --- a/compiler/loco/src/Service/TypeInference.cpp +++ b/compiler/loco/src/Service/TypeInference.cpp @@ -146,6 +146,7 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitorlhs()); } loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); } loco::DataType visit(const loco::TensorBroadcast *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); } loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); } loco::DataType visit(const loco::TransposedConv2D *node) { return loco::dtype_get(node->ifm()); } };