From 3baa813ccdf34d2fe4f9d29d134426ed428106e7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 17 Aug 2018 09:39:49 +0900 Subject: [PATCH] [coco] Implement bag update in Arg (#1057) This commit moves Bag update implementation from Input/Output into Arg. Signed-off-by: Jonghyun Park --- contrib/coco/core/include/coco/IR/Arg.h | 11 ++++++++++ contrib/coco/core/include/coco/IR/Input.h | 9 +++----- contrib/coco/core/include/coco/IR/Output.h | 9 +++----- contrib/coco/core/src/IR/Arg.cpp | 19 ++++++++++++++++- contrib/coco/core/src/IR/Arg.test.cpp | 1 + contrib/coco/core/src/IR/Input.cpp | 34 ++++++++++++------------------ contrib/coco/core/src/IR/Input.test.cpp | 1 - contrib/coco/core/src/IR/Output.cpp | 34 ++++++++++++------------------ contrib/coco/core/src/IR/Output.test.cpp | 1 - 9 files changed, 62 insertions(+), 57 deletions(-) diff --git a/contrib/coco/core/include/coco/IR/Arg.h b/contrib/coco/core/include/coco/IR/Arg.h index 4eb59ff..4241247 100644 --- a/contrib/coco/core/include/coco/IR/Arg.h +++ b/contrib/coco/core/include/coco/IR/Arg.h @@ -31,6 +31,14 @@ public: const std::string &name(void) const { return _name; } void name(const std::string &s) { _name = s; } +protected: + virtual void onTake(Bag *) { return; } + virtual void onRelease(Bag *) { return; } + +public: + Bag *bag(void) const { return _bag; } + void bag(Bag *); + public: ElemID &at(const nncc::core::ADT::tensor::Index &); const ElemID &at(const nncc::core::ADT::tensor::Index &) const; @@ -40,6 +48,9 @@ private: private: std::string _name; + +private: + Bag *_bag; std::vector _map; }; diff --git a/contrib/coco/core/include/coco/IR/Input.h b/contrib/coco/core/include/coco/IR/Input.h index 995b665..0e87d1b 100644 --- a/contrib/coco/core/include/coco/IR/Input.h +++ b/contrib/coco/core/include/coco/IR/Input.h @@ -20,14 +20,11 @@ public: Input(const PtrLink *bag_link, const nncc::core::ADT::tensor::Shape &shape); private: - const PtrLink *const _bag_link; - -public: - Bag *bag(void) const { return _bag; } - void bag(Bag *); + void onTake(Bag *) override; + void onRelease(Bag *) override; private: - Bag *_bag; + const PtrLink *const _bag_link; }; } // namespace coco diff --git a/contrib/coco/core/include/coco/IR/Output.h b/contrib/coco/core/include/coco/IR/Output.h index e9c27db..d735afd 100644 --- a/contrib/coco/core/include/coco/IR/Output.h +++ b/contrib/coco/core/include/coco/IR/Output.h @@ -20,14 +20,11 @@ public: Output(const PtrLink *bag_link, const nncc::core::ADT::tensor::Shape &shape); private: - const PtrLink *const _bag_link; - -public: - Bag *bag(void) const { return _bag; } - void bag(Bag *); + void onTake(Bag *) override; + void onRelease(Bag *) override; private: - Bag *_bag; + const PtrLink *const _bag_link; }; } // namespace coco diff --git a/contrib/coco/core/src/IR/Arg.cpp b/contrib/coco/core/src/IR/Arg.cpp index f9e2e13..637cf0b 100644 --- a/contrib/coco/core/src/IR/Arg.cpp +++ b/contrib/coco/core/src/IR/Arg.cpp @@ -14,11 +14,28 @@ const nncc::core::ADT::tensor::LexicalLayout l; namespace coco { -Arg::Arg(const nncc::core::ADT::tensor::Shape &shape) : _shape{shape} +Arg::Arg(const nncc::core::ADT::tensor::Shape &shape) : _shape{shape}, _bag{nullptr} { _map.resize(nncc::core::ADT::tensor::num_elements(shape)); } +void Arg::bag(Bag *bag) +{ + if (_bag != nullptr) + { + onRelease(_bag); + _bag = nullptr; + } + + assert(_bag == nullptr); + + if (bag != nullptr) + { + _bag = bag; + onTake(_bag); + } +} + ElemID &Arg::at(const nncc::core::ADT::tensor::Index &index) { return _map.at(l.offset(_shape, index)); diff --git a/contrib/coco/core/src/IR/Arg.test.cpp b/contrib/coco/core/src/IR/Arg.test.cpp index 8674144..4a1861f 100644 --- a/contrib/coco/core/src/IR/Arg.test.cpp +++ b/contrib/coco/core/src/IR/Arg.test.cpp @@ -35,6 +35,7 @@ TEST_F(ArgTest, constructor) ASSERT_EQ(arg->shape(), shape); ASSERT_TRUE(arg->name().empty()); + ASSERT_EQ(arg->bag(), nullptr); } TEST_F(ArgTest, name_update) diff --git a/contrib/coco/core/src/IR/Input.cpp b/contrib/coco/core/src/IR/Input.cpp index 2b6e699..7e2afa3 100644 --- a/contrib/coco/core/src/IR/Input.cpp +++ b/contrib/coco/core/src/IR/Input.cpp @@ -7,35 +7,27 @@ namespace coco { Input::Input(const PtrLink *bag_link, const nncc::core::ADT::tensor::Shape &shape) - : Arg{shape}, _bag_link{bag_link}, _bag{nullptr} + : Arg{shape}, _bag_link{bag_link} { // DO NOT?HING } -void Input::bag(Bag *bag) +void Input::onTake(Bag *bag) { - if (_bag != nullptr) - { - auto info = _bag_link->find(_bag); - assert(info != nullptr); - assert(info->type() == BagType::Input); + auto info = _bag_link->find(bag); + assert(info != nullptr); + assert(info->type() == BagType::Intermediate); - info->type(BagType::Intermediate); - _bag = nullptr; - } - - assert(_bag == nullptr); - - if (bag != nullptr) - { - _bag = bag; + info->type(BagType::Input); +} - auto info = _bag_link->find(_bag); - assert(info != nullptr); - assert(info->type() == BagType::Intermediate); +void Input::onRelease(Bag *bag) +{ + auto info = _bag_link->find(bag); + assert(info != nullptr); + assert(info->type() == BagType::Input); - info->type(BagType::Input); - } + info->type(BagType::Intermediate); } } // namespace coco diff --git a/contrib/coco/core/src/IR/Input.test.cpp b/contrib/coco/core/src/IR/Input.test.cpp index 941ff6f..46a00ec 100644 --- a/contrib/coco/core/src/IR/Input.test.cpp +++ b/contrib/coco/core/src/IR/Input.test.cpp @@ -17,7 +17,6 @@ TEST(IR_INPUT, ctor_should_set_shape) coco::Input input{&bag_link, shape}; ASSERT_EQ(input.shape(), shape); - ASSERT_EQ(input.bag(), nullptr); ASSERT_TRUE(input.name().empty()); } diff --git a/contrib/coco/core/src/IR/Output.cpp b/contrib/coco/core/src/IR/Output.cpp index 39ed6cc..34a8f04 100644 --- a/contrib/coco/core/src/IR/Output.cpp +++ b/contrib/coco/core/src/IR/Output.cpp @@ -5,35 +5,27 @@ namespace coco { Output::Output(const PtrLink *bag_link, const nncc::core::ADT::tensor::Shape &shape) - : Arg{shape}, _bag_link{bag_link}, _bag{nullptr} + : Arg{shape}, _bag_link{bag_link} { // DO NOTHING } -void Output::bag(Bag *bag) +void Output::onTake(Bag *bag) { - if (_bag != nullptr) - { - auto info = _bag_link->find(_bag); - assert(info != nullptr); - assert(info->type() == BagType::Output); + auto info = _bag_link->find(bag); + assert(info != nullptr); + assert(info->type() == BagType::Intermediate); - info->type(BagType::Intermediate); - _bag = nullptr; - } - - assert(_bag == nullptr); - - if (bag != nullptr) - { - _bag = bag; + info->type(BagType::Output); +} - auto info = _bag_link->find(_bag); - assert(info != nullptr); - assert(info->type() == BagType::Intermediate); +void Output::onRelease(Bag *bag) +{ + auto info = _bag_link->find(bag); + assert(info != nullptr); + assert(info->type() == BagType::Output); - info->type(BagType::Output); - } + info->type(BagType::Intermediate); } } // namespace coco diff --git a/contrib/coco/core/src/IR/Output.test.cpp b/contrib/coco/core/src/IR/Output.test.cpp index 45c1525..3eba1da 100644 --- a/contrib/coco/core/src/IR/Output.test.cpp +++ b/contrib/coco/core/src/IR/Output.test.cpp @@ -17,7 +17,6 @@ TEST(IR_OUTPUT, ctor_should_set_shape) coco::Output output{&bag_link, shape}; ASSERT_EQ(output.shape(), shape); - ASSERT_EQ(output.bag(), nullptr); } TEST(IR_OUTPUT, bag_update) -- 2.7.4