[coco] Implement bag update in Arg (#1057)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 17 Aug 2018 00:39:49 +0000 (09:39 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 17 Aug 2018 00:39:49 +0000 (09:39 +0900)
This commit moves Bag update implementation from Input/Output into Arg.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/coco/core/include/coco/IR/Arg.h
contrib/coco/core/include/coco/IR/Input.h
contrib/coco/core/include/coco/IR/Output.h
contrib/coco/core/src/IR/Arg.cpp
contrib/coco/core/src/IR/Arg.test.cpp
contrib/coco/core/src/IR/Input.cpp
contrib/coco/core/src/IR/Input.test.cpp
contrib/coco/core/src/IR/Output.cpp
contrib/coco/core/src/IR/Output.test.cpp

index 4eb59ff..4241247 100644 (file)
@@ -31,6 +31,14 @@ public:
   const std::string &name(void) const { return _name; }
   void name(const std::string &s) { _name = s; }
 
+protected:
+  virtual void onTake(Bag *) { return; }
+  virtual void onRelease(Bag *) { return; }
+
+public:
+  Bag *bag(void) const { return _bag; }
+  void bag(Bag *);
+
 public:
   ElemID &at(const nncc::core::ADT::tensor::Index &);
   const ElemID &at(const nncc::core::ADT::tensor::Index &) const;
@@ -40,6 +48,9 @@ private:
 
 private:
   std::string _name;
+
+private:
+  Bag *_bag;
   std::vector<ElemID> _map;
 };
 
index 995b665..0e87d1b 100644 (file)
@@ -20,14 +20,11 @@ public:
   Input(const PtrLink<Bag, BagInfo> *bag_link, const nncc::core::ADT::tensor::Shape &shape);
 
 private:
-  const PtrLink<Bag, BagInfo> *const _bag_link;
-
-public:
-  Bag *bag(void) const { return _bag; }
-  void bag(Bag *);
+  void onTake(Bag *) override;
+  void onRelease(Bag *) override;
 
 private:
-  Bag *_bag;
+  const PtrLink<Bag, BagInfo> *const _bag_link;
 };
 
 } // namespace coco
index e9c27db..d735afd 100644 (file)
@@ -20,14 +20,11 @@ public:
   Output(const PtrLink<Bag, BagInfo> *bag_link, const nncc::core::ADT::tensor::Shape &shape);
 
 private:
-  const PtrLink<Bag, BagInfo> *const _bag_link;
-
-public:
-  Bag *bag(void) const { return _bag; }
-  void bag(Bag *);
+  void onTake(Bag *) override;
+  void onRelease(Bag *) override;
 
 private:
-  Bag *_bag;
+  const PtrLink<Bag, BagInfo> *const _bag_link;
 };
 
 } // namespace coco
index f9e2e13..637cf0b 100644 (file)
@@ -14,11 +14,28 @@ const nncc::core::ADT::tensor::LexicalLayout l;
 namespace coco
 {
 
-Arg::Arg(const nncc::core::ADT::tensor::Shape &shape) : _shape{shape}
+Arg::Arg(const nncc::core::ADT::tensor::Shape &shape) : _shape{shape}, _bag{nullptr}
 {
   _map.resize(nncc::core::ADT::tensor::num_elements(shape));
 }
 
+void Arg::bag(Bag *bag)
+{
+  if (_bag != nullptr)
+  {
+    onRelease(_bag);
+    _bag = nullptr;
+  }
+
+  assert(_bag == nullptr);
+
+  if (bag != nullptr)
+  {
+    _bag = bag;
+    onTake(_bag);
+  }
+}
+
 ElemID &Arg::at(const nncc::core::ADT::tensor::Index &index)
 {
   return _map.at(l.offset(_shape, index));
index 8674144..4a1861f 100644 (file)
@@ -35,6 +35,7 @@ TEST_F(ArgTest, constructor)
 
   ASSERT_EQ(arg->shape(), shape);
   ASSERT_TRUE(arg->name().empty());
+  ASSERT_EQ(arg->bag(), nullptr);
 }
 
 TEST_F(ArgTest, name_update)
index 2b6e699..7e2afa3 100644 (file)
@@ -7,35 +7,27 @@ namespace coco
 {
 
 Input::Input(const PtrLink<Bag, BagInfo> *bag_link, const nncc::core::ADT::tensor::Shape &shape)
-    : Arg{shape}, _bag_link{bag_link}, _bag{nullptr}
+    : Arg{shape}, _bag_link{bag_link}
 {
   // DO NOT?HING
 }
 
-void Input::bag(Bag *bag)
+void Input::onTake(Bag *bag)
 {
-  if (_bag != nullptr)
-  {
-    auto info = _bag_link->find(_bag);
-    assert(info != nullptr);
-    assert(info->type() == BagType::Input);
+  auto info = _bag_link->find(bag);
+  assert(info != nullptr);
+  assert(info->type() == BagType::Intermediate);
 
-    info->type(BagType::Intermediate);
-    _bag = nullptr;
-  }
-
-  assert(_bag == nullptr);
-
-  if (bag != nullptr)
-  {
-    _bag = bag;
+  info->type(BagType::Input);
+}
 
-    auto info = _bag_link->find(_bag);
-    assert(info != nullptr);
-    assert(info->type() == BagType::Intermediate);
+void Input::onRelease(Bag *bag)
+{
+  auto info = _bag_link->find(bag);
+  assert(info != nullptr);
+  assert(info->type() == BagType::Input);
 
-    info->type(BagType::Input);
-  }
+  info->type(BagType::Intermediate);
 }
 
 } // namespace coco
index 941ff6f..46a00ec 100644 (file)
@@ -17,7 +17,6 @@ TEST(IR_INPUT, ctor_should_set_shape)
   coco::Input input{&bag_link, shape};
 
   ASSERT_EQ(input.shape(), shape);
-  ASSERT_EQ(input.bag(), nullptr);
   ASSERT_TRUE(input.name().empty());
 }
 
index 39ed6cc..34a8f04 100644 (file)
@@ -5,35 +5,27 @@ namespace coco
 {
 
 Output::Output(const PtrLink<Bag, BagInfo> *bag_link, const nncc::core::ADT::tensor::Shape &shape)
-    : Arg{shape}, _bag_link{bag_link}, _bag{nullptr}
+    : Arg{shape}, _bag_link{bag_link}
 {
   // DO NOTHING
 }
 
-void Output::bag(Bag *bag)
+void Output::onTake(Bag *bag)
 {
-  if (_bag != nullptr)
-  {
-    auto info = _bag_link->find(_bag);
-    assert(info != nullptr);
-    assert(info->type() == BagType::Output);
+  auto info = _bag_link->find(bag);
+  assert(info != nullptr);
+  assert(info->type() == BagType::Intermediate);
 
-    info->type(BagType::Intermediate);
-    _bag = nullptr;
-  }
-
-  assert(_bag == nullptr);
-
-  if (bag != nullptr)
-  {
-    _bag = bag;
+  info->type(BagType::Output);
+}
 
-    auto info = _bag_link->find(_bag);
-    assert(info != nullptr);
-    assert(info->type() == BagType::Intermediate);
+void Output::onRelease(Bag *bag)
+{
+  auto info = _bag_link->find(bag);
+  assert(info != nullptr);
+  assert(info->type() == BagType::Output);
 
-    info->type(BagType::Output);
-  }
+  info->type(BagType::Intermediate);
 }
 
 } // namespace coco
index 45c1525..3eba1da 100644 (file)
@@ -17,7 +17,6 @@ TEST(IR_OUTPUT, ctor_should_set_shape)
   coco::Output output{&bag_link, shape};
 
   ASSERT_EQ(output.shape(), shape);
-  ASSERT_EQ(output.bag(), nullptr);
 }
 
 TEST(IR_OUTPUT, bag_update)