From: 윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Wed, 29 May 2019 06:55:44 +0000 (+0900) Subject: [loco] Reshape with shape known at compile time (#3615) X-Git-Tag: nncc_backup~504 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=bdf1316f074527194bfdef1485bbbd4242fc6c9f;p=platform%2Fcore%2Fml%2Fnnfw.git [loco] Reshape with shape known at compile time (#3615) * [loco] FixedReshape This commit adds a loco node that reshapes a tensor with statically known shape. Signed-off-by: Hyun Sik Yoon * use template. revise comment --- diff --git a/contrib/loco/include/loco/IR/Nodes.h b/contrib/loco/include/loco/IR/Nodes.h index 71e352d..c8e9c61 100644 --- a/contrib/loco/include/loco/IR/Nodes.h +++ b/contrib/loco/include/loco/IR/Nodes.h @@ -258,6 +258,39 @@ private: std::unique_ptr _dec{nullptr}; }; +enum class ReshapeType +{ + Fixed, // shape is known at compile time + // Add another type for a case when shape is not known at compile time +}; + +template class Reshape; + +/** + * @brief Reshape a tensor to another tensor whose shape is known at compile time + * + * @note This class reshapes the shape of an input tensor to _shape. + * Each dimension of _shape should be known at compile time. + * Any dimension of _shape should be greater than 0. + * + * Interpreter or runtime should lexicographically copy an input tensor into an output tensor. + * For example, values of an input tesor of shape [2, 2, 2, 2] will be copied into an output + * tensor of new shape [4, 4] like the following: + * input[0, 0, 0, 0] => output [0, 0] + * input[0, 0, 0, 1] => output [0, 1] + * input[0, 0, 1, 0] => output [0, 2] + * ... + * input[1, 1, 1, 1] => output [3, 3] + */ +template <> +class Reshape final : public FixedArityNode<1>, + public NodeMixin +{ +public: + Node *input(void) const { return at(0)->node(); } + void input(Node *node) { at(0)->node(node); } +}; + } // namespace loco #endif // __LOCO_IR_NODES_H__ diff --git a/contrib/loco/src/IR/Nodes.test.cpp b/contrib/loco/src/IR/Nodes.test.cpp index 7c1da7a..edd1ed8 100644 --- a/contrib/loco/src/IR/Nodes.test.cpp +++ b/contrib/loco/src/IR/Nodes.test.cpp @@ -146,3 +146,20 @@ TEST(FeatureDecodeTest, constructor) ASSERT_EQ(feature_decode.input(), nullptr); ASSERT_EQ(feature_decode.decoder(), nullptr); } + +TEST(Reshape_Fixed_Test, constructor) +{ + loco::Reshape reshape; + + ASSERT_EQ(reshape.rank(), 0); +} + +TEST(Reshape_Fixed_Test, shape) +{ + loco::Reshape reshape; + reshape.shape({2, 3}); + + ASSERT_EQ(reshape.rank(), 2); + ASSERT_EQ(reshape.dim(0), 2); + ASSERT_EQ(reshape.dim(1), 3); +}