From ceaac924c688120369977d17ed4b3d07d40555ec Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 14 Jun 2019 14:08:23 +0900 Subject: [PATCH] [loco] Introduce BiasEncode (#3766) This commit extends loco with BiasEncode node which converts a tensor as a bias. Signed-off-by: Jonghyun Park --- contrib/loco/include/loco/IR/CanonicalNodes.lst | 1 + contrib/loco/include/loco/IR/Domain.h | 1 + contrib/loco/include/loco/IR/Nodes.h | 15 +++++++++++++++ contrib/loco/src/IR/Nodes.test.cpp | 10 ++++++++++ 4 files changed, 27 insertions(+) diff --git a/contrib/loco/include/loco/IR/CanonicalNodes.lst b/contrib/loco/include/loco/IR/CanonicalNodes.lst index 593d124..004e49d 100644 --- a/contrib/loco/include/loco/IR/CanonicalNodes.lst +++ b/contrib/loco/include/loco/IR/CanonicalNodes.lst @@ -8,6 +8,7 @@ // CANONICAL_NODE(OPCODE, CLASS) CANONICAL_NODE(AvgPool2D, AvgPool2D) +CANONICAL_NODE(BiasEncode, BiasEncode) CANONICAL_NODE(ConstGen, ConstGen) CANONICAL_NODE(Conv2D, Conv2D) CANONICAL_NODE(Forward, Forward) diff --git a/contrib/loco/include/loco/IR/Domain.h b/contrib/loco/include/loco/IR/Domain.h index d4ca518..7158ef3 100644 --- a/contrib/loco/include/loco/IR/Domain.h +++ b/contrib/loco/include/loco/IR/Domain.h @@ -42,6 +42,7 @@ enum class Domain Tensor, Feature, Filter, + Bias, /* ... */ }; diff --git a/contrib/loco/include/loco/IR/Nodes.h b/contrib/loco/include/loco/IR/Nodes.h index 89d8b56..0fb9e55 100644 --- a/contrib/loco/include/loco/IR/Nodes.h +++ b/contrib/loco/include/loco/IR/Nodes.h @@ -414,6 +414,21 @@ private: // TODO Support "Dilation" }; +/** + * @brief Create a "Bias" from a "Tensor" + * + * BiasEncode currently requires a rank-1 tensor as its input. + */ +class BiasEncode final : public FixedArityNode<1, CanonicalNodeImpl> +{ +public: + BiasEncode() = default; + +public: + Node *input(void) const { return at(0)->node(); } + void input(Node *node) { at(0)->node(node); } +}; + } // namespace loco #endif // __LOCO_IR_NODES_H__ diff --git a/contrib/loco/src/IR/Nodes.test.cpp b/contrib/loco/src/IR/Nodes.test.cpp index f618831..247a898 100644 --- a/contrib/loco/src/IR/Nodes.test.cpp +++ b/contrib/loco/src/IR/Nodes.test.cpp @@ -329,3 +329,13 @@ TEST(Conv2DTest, constructor) ASSERT_EQ(conv2d.stride()->vertical(), 1); ASSERT_EQ(conv2d.stride()->horizontal(), 1); } + +TEST(BiasEncodeTest, constructor) +{ + loco::BiasEncode bias_encode; + + ASSERT_EQ(bias_encode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(bias_encode.opcode(), loco::CanonicalOpcode::BiasEncode); + + ASSERT_EQ(bias_encode.input(), nullptr); +} -- 2.7.4