[loco] Introduce TensorReduce Operation (#7497)
author남궁석/On-Device Lab(SR)/Engineer/삼성전자 <sk.namkoong@samsung.com>
Wed, 18 Sep 2019 06:17:24 +0000 (15:17 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 18 Sep 2019 06:17:24 +0000 (15:17 +0900)
* [loco] Introduce TensorReduce Operation

This commit will introduce `TensorReduce` operation in loco

Signed-off-by: Seok NamKoong <sk.namkoong@samsung.com>
* fix wrong shapeInference

* revise comments

compiler/loco/include/loco/IR/CanonicalNodes.lst
compiler/loco/include/loco/IR/Nodes.h
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/TypeInference.cpp

index 32d0629..71fc8d0 100644 (file)
@@ -36,5 +36,6 @@ CANONICAL_NODE(Tanh, Tanh)
 CANONICAL_NODE(TensorConcat, TensorConcat)
 CANONICAL_NODE(TensorBiasAdd, BiasAdd<Domain::Tensor>)
 CANONICAL_NODE(TensorBroadcast, TensorBroadcast)
+CANONICAL_NODE(TensorReduce, TensorReduce)
 CANONICAL_NODE(TensorSoftmax, Softmax<Domain::Tensor>)
 CANONICAL_NODE(TransposedConv2D, TransposedConv2D)
index 316b263..6ffad19 100644 (file)
@@ -27,6 +27,7 @@
 #include "loco/IR/Stride.h"
 #include "loco/IR/Padding2D.h"
 #include "loco/IR/TensorAxis.h"
+#include "loco/IR/TensorAxisSet.h"
 #include "loco/IR/FeatureCodec.h"
 #include "loco/IR/FilterCodec.h"
 #include "loco/IR/DepthwiseFilterCodec.h"
@@ -566,6 +567,39 @@ private:
 };
 
 /**
+ * @brief Reduce type functions
+ */
+enum class ReduceFunc
+{
+  Mean, // ReduceMean
+  // TODO Support other reduce operations
+};
+
+/**
+ * @brief Computes ReduceFunc operations for Tensor domain
+ * @note  All the reduce functions always keep dimensions
+ */
+class TensorReduce final
+    : public CanonicalNodeDef<CanonicalOpcode::TensorReduce, FixedArity<1>::Mixin>
+{
+public:
+  Node *input(void) const { return at(0)->node(); }
+  void input(Node *node) { at(0)->node(node); }
+
+public:
+  const TensorAxisSet *axes(void) const { return &_axes; }
+  TensorAxisSet *axes(void) { return &_axes; }
+
+public:
+  ReduceFunc func(void) const { return _func; }
+  void func(ReduceFunc func) { _func = func; }
+
+private:
+  TensorAxisSet _axes;
+  ReduceFunc _func;
+};
+
+/**
  * @brief 2D Transposed Convolution
  *
  * @note  TransposedConv2D have a few important conventions that IR users should
index 5934007..f0e4ec6 100644 (file)
@@ -547,6 +547,19 @@ public:
     return loco::NodeShape{tensor_shape};
   }
 
+  // CASE: TensorReduce
+  loco::NodeShape visit(const loco::TensorReduce *node) final
+  {
+    auto tensor_shape = node_shape(node->input()).as<loco::TensorShape>();
+    auto const tensor_rank = tensor_shape.rank();
+
+    for (uint32_t d = 0; d < tensor_rank; ++d)
+      if (node->axes()->defined(d))
+        tensor_shape.dim(d) = 1;
+
+    return loco::NodeShape{tensor_shape};
+  }
+
   // CASE: TensorSoftmax
   loco::NodeShape visit(const loco::TensorSoftmax *node) final { return node_shape(node->input()); }
 };
index 92cdbd5..53915a9 100644 (file)
@@ -146,6 +146,7 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitor<l
   loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); }
   loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); }
   loco::DataType visit(const loco::TensorBroadcast *node) { return loco::dtype_get(node->input()); }
+  loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::TransposedConv2D *node) { return loco::dtype_get(node->ifm()); }
 };