[loco] Update TensorShaped classes with initializer_list (#6284)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 6 Aug 2019 09:36:10 +0000 (18:36 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 6 Aug 2019 09:36:10 +0000 (18:36 +0900)
Now, "shape" setter in TensorShaped mixin takes initializer_list
as its input.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/loco/include/loco/IR/Graph.h
compiler/loco/src/IR/Graph.cpp
compiler/loco/src/IR/Graph.test.cpp

index e0fa5a2..f94b4de 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "loco/ADT/ObjectPool.h"
 
+#include <initializer_list>
 #include <set>
 #include <string>
 #include <memory>
@@ -41,6 +42,7 @@ enum class Trait
   // Any "TensorShaped" class has the following methods
   // - const TensorShape *shape(void) const;
   // - void shape(std::unique_ptr<TensorShape> &&);
+  // - void shape(std::initializer_list<Dimension> &&);
   //
   // TODO Rename NodeMixin::TensorShape as NodeMixin::NDShape
   TensorShaped,
@@ -70,6 +72,7 @@ public:
 public:
   const TensorShape *shape(void) const { return _shape.get(); }
   void shape(std::unique_ptr<TensorShape> &&shape) { _shape = std::move(shape); }
+  void shape(std::initializer_list<Dimension> dims);
 
 private:
   std::unique_ptr<TensorShape> _shape = nullptr;
index 4c9dad5..c7c5984 100644 (file)
 
 #include <cassert>
 
+namespace
+{
+
+std::unique_ptr<loco::TensorShape> make_tensor_shape(std::initializer_list<loco::Dimension> dims)
+{
+  auto tensor_shape = stdex::make_unique<loco::TensorShape>();
+
+  tensor_shape->rank(dims.size());
+  {
+    uint32_t axis = 0;
+    for (auto it = dims.begin(); it != dims.end(); ++it)
+    {
+      tensor_shape->dim(axis++) = *it;
+    }
+    assert(axis == dims.size());
+  }
+
+  return std::move(tensor_shape);
+}
+
+} // namespace
+
 namespace loco
 {
 
+void Mixin<Trait::TensorShaped>::shape(std::initializer_list<Dimension> dims)
+{
+  shape(make_tensor_shape(dims));
+}
+
 void GraphInput::node(Pull *pull)
 {
   _pull = pull;
index db74db1..acef83f 100644 (file)
@@ -59,6 +59,19 @@ TEST(DataTypedMixinTest, setter_and_getter)
   ASSERT_EQ(mixin.dtype(), loco::DataType::FLOAT32);
 }
 
+TEST(TensorShapedMixinTest, setter_and_getter)
+{
+  loco::Mixin<loco::Trait::TensorShaped> mixin;
+
+  mixin.shape({1, 2, 3, 4});
+  ASSERT_NE(mixin.shape(), nullptr);
+  ASSERT_EQ(mixin.shape()->rank(), 4);
+  ASSERT_EQ(mixin.shape()->dim(0), 1);
+  ASSERT_EQ(mixin.shape()->dim(1), 2);
+  ASSERT_EQ(mixin.shape()->dim(2), 3);
+  ASSERT_EQ(mixin.shape()->dim(3), 4);
+}
+
 TEST(GraphTest, create_and_destroy_node)
 {
   auto g = loco::make_graph();