[loco] Introduce Tanh Operation (#6830)
author남궁석/On-Device Lab(SR)/Engineer/삼성전자 <sk.namkoong@samsung.com>
Thu, 22 Aug 2019 08:05:04 +0000 (17:05 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 22 Aug 2019 08:05:04 +0000 (17:05 +0900)
This commit will introduce `Tanh` operation in loco

Signed-off-by: Seok NamKoong <sk.namkoong@samsung.com>
compiler/loco/include/loco/IR/CanonicalNodes.lst
compiler/loco/include/loco/IR/Nodes.h
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/TypeInference.cpp

index 094ee81..2acc259 100644 (file)
@@ -31,6 +31,7 @@ CANONICAL_NODE(Push, Push)
 CANONICAL_NODE(Pull, Pull)
 CANONICAL_NODE(ReLU, ReLU)
 CANONICAL_NODE(ReLU6, ReLU6)
+CANONICAL_NODE(Tanh, Tanh)
 CANONICAL_NODE(TensorConcat, TensorConcat)
 CANONICAL_NODE(TensorBiasAdd, BiasAdd<Domain::Tensor>)
 CANONICAL_NODE(TensorSoftmax, Softmax<Domain::Tensor>)
index d86ed56..9f4a73d 100644 (file)
@@ -182,6 +182,19 @@ public:
 };
 
 /**
+ * @brief Create a new value that rectifies its input by tanh
+ */
+class Tanh final : public CanonicalNodeDef<CanonicalOpcode::Tanh, FixedArity<1>::Mixin>
+{
+public:
+  Tanh() = default;
+
+public:
+  Node *input(void) const { return at(0)->node(); }
+  void input(Node *node) { at(0)->node(node); }
+};
+
+/**
  * @brief Create a value from constant byte array
  *
  * @note ConstGen assumes "lexical memory layout".
index d561878..4e6a466 100644 (file)
@@ -409,6 +409,9 @@ public:
   // CASE: ReLU6
   loco::NodeShape visit(const loco::ReLU6 *node) final { return loco::shape_get(node->input()); }
 
+  // CASE: Tanh
+  loco::NodeShape visit(const loco::Tanh *node) final { return loco::shape_get(node->input()); }
+
   // CASE: TensorBiasAdd
   loco::NodeShape visit(const loco::TensorBiasAdd *node) final
   {
index 2cf1716..3c47ed9 100644 (file)
@@ -119,6 +119,7 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitor<l
   loco::DataType visit(const loco::Pull *node) { return node->dtype(); }
   loco::DataType visit(const loco::ReLU *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::ReLU6 *node) { return loco::dtype_get(node->input()); }
+  loco::DataType visit(const loco::Tanh *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); }
   loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); }
   loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); }