[loco] Reshape with shape known at compile time (#3615)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 29 May 2019 06:55:44 +0000 (15:55 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 29 May 2019 06:55:44 +0000 (15:55 +0900)
* [loco] FixedReshape

This commit adds a loco node that reshapes a tensor with statically known shape.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* use template. revise comment

contrib/loco/include/loco/IR/Nodes.h
contrib/loco/src/IR/Nodes.test.cpp

index 71e352d..c8e9c61 100644 (file)
@@ -258,6 +258,39 @@ private:
   std::unique_ptr<FeatureDecoder> _dec{nullptr};
 };
 
+enum class ReshapeType
+{
+  Fixed, // shape is known at compile time
+  // Add another type for a case when shape is not known at compile time
+};
+
+template <ReshapeType RT> class Reshape;
+
+/**
+ * @brief Reshape a tensor to another tensor whose shape is known at compile time
+ *
+ * @note This class reshapes the shape of an input tensor to _shape.
+ *       Each dimension of _shape should be known at compile time.
+ *       Any dimension of _shape should be greater than 0.
+ *
+ *       Interpreter or runtime should lexicographically copy an input tensor into an output tensor.
+ *       For example, values of an input tesor of shape [2, 2, 2, 2] will be copied into an output
+ *       tensor of new shape [4, 4] like the following:
+ *         input[0, 0, 0, 0] => output [0, 0]
+ *         input[0, 0, 0, 1] => output [0, 1]
+ *         input[0, 0, 1, 0] => output [0, 2]
+ *         ...
+ *         input[1, 1, 1, 1] => output [3, 3]
+ */
+template <>
+class Reshape<ReshapeType::Fixed> final : public FixedArityNode<1>,
+                                          public NodeMixin<NodeTrait::TensorShape>
+{
+public:
+  Node *input(void) const { return at(0)->node(); }
+  void input(Node *node) { at(0)->node(node); }
+};
+
 } // namespace loco
 
 #endif // __LOCO_IR_NODES_H__
index 7c1da7a..edd1ed8 100644 (file)
@@ -146,3 +146,20 @@ TEST(FeatureDecodeTest, constructor)
   ASSERT_EQ(feature_decode.input(), nullptr);
   ASSERT_EQ(feature_decode.decoder(), nullptr);
 }
+
+TEST(Reshape_Fixed_Test, constructor)
+{
+  loco::Reshape<loco::ReshapeType::Fixed> reshape;
+
+  ASSERT_EQ(reshape.rank(), 0);
+}
+
+TEST(Reshape_Fixed_Test, shape)
+{
+  loco::Reshape<loco::ReshapeType::Fixed> reshape;
+  reshape.shape({2, 3});
+
+  ASSERT_EQ(reshape.rank(), 2);
+  ASSERT_EQ(reshape.dim(0), 2);
+  ASSERT_EQ(reshape.dim(1), 3);
+}