This commit moves Bag update implementation from Input/Output into Arg.
Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
const std::string &name(void) const { return _name; }
void name(const std::string &s) { _name = s; }
+protected:
+ virtual void onTake(Bag *) { return; }
+ virtual void onRelease(Bag *) { return; }
+
+public:
+ Bag *bag(void) const { return _bag; }
+ void bag(Bag *);
+
public:
ElemID &at(const nncc::core::ADT::tensor::Index &);
const ElemID &at(const nncc::core::ADT::tensor::Index &) const;
private:
std::string _name;
+
+private:
+ Bag *_bag;
std::vector<ElemID> _map;
};
Input(const PtrLink<Bag, BagInfo> *bag_link, const nncc::core::ADT::tensor::Shape &shape);
private:
- const PtrLink<Bag, BagInfo> *const _bag_link;
-
-public:
- Bag *bag(void) const { return _bag; }
- void bag(Bag *);
+ void onTake(Bag *) override;
+ void onRelease(Bag *) override;
private:
- Bag *_bag;
+ const PtrLink<Bag, BagInfo> *const _bag_link;
};
} // namespace coco
Output(const PtrLink<Bag, BagInfo> *bag_link, const nncc::core::ADT::tensor::Shape &shape);
private:
- const PtrLink<Bag, BagInfo> *const _bag_link;
-
-public:
- Bag *bag(void) const { return _bag; }
- void bag(Bag *);
+ void onTake(Bag *) override;
+ void onRelease(Bag *) override;
private:
- Bag *_bag;
+ const PtrLink<Bag, BagInfo> *const _bag_link;
};
} // namespace coco
namespace coco
{
-Arg::Arg(const nncc::core::ADT::tensor::Shape &shape) : _shape{shape}
+Arg::Arg(const nncc::core::ADT::tensor::Shape &shape) : _shape{shape}, _bag{nullptr}
{
_map.resize(nncc::core::ADT::tensor::num_elements(shape));
}
+void Arg::bag(Bag *bag)
+{
+ if (_bag != nullptr)
+ {
+ onRelease(_bag);
+ _bag = nullptr;
+ }
+
+ assert(_bag == nullptr);
+
+ if (bag != nullptr)
+ {
+ _bag = bag;
+ onTake(_bag);
+ }
+}
+
ElemID &Arg::at(const nncc::core::ADT::tensor::Index &index)
{
return _map.at(l.offset(_shape, index));
ASSERT_EQ(arg->shape(), shape);
ASSERT_TRUE(arg->name().empty());
+ ASSERT_EQ(arg->bag(), nullptr);
}
TEST_F(ArgTest, name_update)
{
Input::Input(const PtrLink<Bag, BagInfo> *bag_link, const nncc::core::ADT::tensor::Shape &shape)
- : Arg{shape}, _bag_link{bag_link}, _bag{nullptr}
+ : Arg{shape}, _bag_link{bag_link}
{
// DO NOT?HING
}
-void Input::bag(Bag *bag)
+void Input::onTake(Bag *bag)
{
- if (_bag != nullptr)
- {
- auto info = _bag_link->find(_bag);
- assert(info != nullptr);
- assert(info->type() == BagType::Input);
+ auto info = _bag_link->find(bag);
+ assert(info != nullptr);
+ assert(info->type() == BagType::Intermediate);
- info->type(BagType::Intermediate);
- _bag = nullptr;
- }
-
- assert(_bag == nullptr);
-
- if (bag != nullptr)
- {
- _bag = bag;
+ info->type(BagType::Input);
+}
- auto info = _bag_link->find(_bag);
- assert(info != nullptr);
- assert(info->type() == BagType::Intermediate);
+void Input::onRelease(Bag *bag)
+{
+ auto info = _bag_link->find(bag);
+ assert(info != nullptr);
+ assert(info->type() == BagType::Input);
- info->type(BagType::Input);
- }
+ info->type(BagType::Intermediate);
}
} // namespace coco
coco::Input input{&bag_link, shape};
ASSERT_EQ(input.shape(), shape);
- ASSERT_EQ(input.bag(), nullptr);
ASSERT_TRUE(input.name().empty());
}
{
Output::Output(const PtrLink<Bag, BagInfo> *bag_link, const nncc::core::ADT::tensor::Shape &shape)
- : Arg{shape}, _bag_link{bag_link}, _bag{nullptr}
+ : Arg{shape}, _bag_link{bag_link}
{
// DO NOTHING
}
-void Output::bag(Bag *bag)
+void Output::onTake(Bag *bag)
{
- if (_bag != nullptr)
- {
- auto info = _bag_link->find(_bag);
- assert(info != nullptr);
- assert(info->type() == BagType::Output);
+ auto info = _bag_link->find(bag);
+ assert(info != nullptr);
+ assert(info->type() == BagType::Intermediate);
- info->type(BagType::Intermediate);
- _bag = nullptr;
- }
-
- assert(_bag == nullptr);
-
- if (bag != nullptr)
- {
- _bag = bag;
+ info->type(BagType::Output);
+}
- auto info = _bag_link->find(_bag);
- assert(info != nullptr);
- assert(info->type() == BagType::Intermediate);
+void Output::onRelease(Bag *bag)
+{
+ auto info = _bag_link->find(bag);
+ assert(info != nullptr);
+ assert(info->type() == BagType::Output);
- info->type(BagType::Output);
- }
+ info->type(BagType::Intermediate);
}
} // namespace coco
coco::Output output{&bag_link, shape};
ASSERT_EQ(output.shape(), shape);
- ASSERT_EQ(output.bag(), nullptr);
}
TEST(IR_OUTPUT, bag_update)