[coco] Bag-level weight allocation (#1932)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 23 Oct 2018 00:46:44 +0000 (09:46 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 23 Oct 2018 00:46:44 +0000 (09:46 +0900)
This commit introduces a new allocate API in PlainWieghtContext<T> which
allows users to allocate a weight for each bag.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/coco/generic/include/coco/IR/PlainWeightContext.h
contrib/coco/generic/src/IR/Data.cpp
contrib/coco/generic/src/IR/Data.test.cpp

index 062cbdb..be50c39 100644 (file)
@@ -37,6 +37,19 @@ template <typename T> struct PlainWeightContext
 {
   virtual ~PlainWeightContext() = default;
 
+  /**
+   * @brief Allocate a weight space for a given blob
+   *
+   * @require the following code SHOULD work for any bag "b":
+   *   PlainWeightContext<T> ctx;
+   *
+   *   ctx.allocate(b);
+   *   auto span = ctx.weight(b);
+   *   assert(span.data() != nullptr);
+   *   assert(span.size() == bag->size());
+   */
+  virtual void allocate(const Bag *) = 0;
+  // WARN Depercated
   virtual void allocate(const KernelObject *) = 0;
 
   /**
index 34e57e3..ec0a297 100644 (file)
@@ -120,6 +120,12 @@ public:
   PlainWeightContextImpl(PlainWeightContextImpl &&) = delete;
 
 public:
+  void allocate(const coco::Bag *bag) override
+  {
+    assert(bag != nullptr);
+    _blob->allocate(bag, sizeof(T));
+  }
+
   void allocate(const coco::KernelObject *obj)
   {
     assert(obj != nullptr);
index 5c8b857..cd3bea6 100644 (file)
@@ -33,6 +33,34 @@ TEST(IR_DATA, construct)
   ASSERT_EQ(mutable_ptr->f32(), immutable_ptr->f32());
 }
 
+TEST(IR_DATA, allocate_and_link_bag)
+{
+  auto m = coco::Module::create();
+  auto d = coco::Data::create();
+
+  // Create a bag
+  auto bag = m->entity()->bag()->create(9);
+
+  // weight(...) SHOULD return a null-span for an invalid bag
+  {
+    auto span = d->f32()->weight(bag);
+
+    ASSERT_EQ(span.data(), nullptr);
+    ASSERT_EQ(span.size(), 0);
+  }
+
+  // Allocate a weight space
+  d->f32()->allocate(bag);
+
+  // weight(...) SHOULD return a valid for a valid bag
+  {
+    auto span = d->f32()->weight(bag);
+
+    ASSERT_NE(span.data(), nullptr);
+    ASSERT_EQ(span.size(), bag->size());
+  }
+}
+
 TEST(IR_DATA, allocate_and_link_kernel)
 {
   using nncc::core::ADT::kernel::num_elements;