[enco] Unify Compatiblity Check and ANN IR Builder (#1534)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 18 Sep 2018 06:41:57 +0000 (15:41 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 18 Sep 2018 06:41:57 +0000 (15:41 +0900)
This commit revises ANNOpBuilder to return Appender for a given
instruction instead of appending it.

This change allows us to unify independent CompatibilityCheck into
ANNOpBuilder.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/enco/core/src/Transforms/Split.cpp

index d3d3347..678c4e8 100644 (file)
@@ -13,50 +13,43 @@ namespace
 
 using Appender = std::function<void(ANNBinder *binder)>;
 
-class ANNOpBuilder : public coco::Instr::Visitor<void>
+class ANNOpBuilder : public coco::Instr::Visitor<Appender>
 {
 public:
-  ANNOpBuilder(ANNBinder *binder, coco::Data *data) : _binder{binder}, _data{data}
+  ANNOpBuilder(coco::Data *data) : _data{data}
   {
     // DO NOTHING
   }
 
 public:
-  void visit(const coco::UnitF *unit)
+  Appender visit(const coco::UnitF *unit)
   {
     if (unit->op()->asConv2D())
     {
-      auto f = conv2d(unit);
-      f(_binder);
+      return conv2d(unit);
     }
     else if (unit->op()->asReLU())
     {
-      auto f = relu(unit);
-      f(_binder);
+      return relu(unit);
     }
     else if (unit->op()->asMaxPool2D())
     {
-      auto f = maxpool2d(unit);
-      f(_binder);
+      return maxpool2d(unit);
     }
     else if (unit->op()->asAvgPool2D())
     {
-      auto f = avgpool2d(unit);
-      f(_binder);
+      return avgpool2d(unit);
     }
     else if (unit->op()->asPadF())
     {
-      auto f = pad(unit);
-      f(_binder);
-    }
-    else
-    {
-      throw std::runtime_error{"Not supported, yet"};
+      return pad(unit);
     }
+
+    return nullptr;
   }
 
 public:
-  void visit(const coco::Shuffle *) { throw std::runtime_error{"Not supported, yet"}; }
+  Appender visit(const coco::Shuffle *) { return nullptr; }
 
 private:
   Appender conv2d(const coco::UnitF *unit) const
@@ -196,10 +189,18 @@ private:
 
   Appender avgpool2d(const coco::UnitF *unit)
   {
-    assert(unit->op()->asAvgPool2D());
+    auto avgpool = unit->op()->asAvgPool2D();
+    assert(avgpool != nullptr);
+
+    if (avgpool->divisor() != coco::AvgPool2D::Divisor::PaddingExcluded)
+    {
+      // When ANN runtime computes the average of each receptive field,
+      // it uses the number of valid(=non-padding) elements as a divisor.
+      return nullptr;
+    }
+
     // TODO Rename "_binder"
-    return [unit](ANNBinder *_binder) {
-      auto avgpool = unit->op()->asAvgPool2D();
+    return [unit, avgpool](ANNBinder *_binder) {
       auto ifm = _binder->addOperand<float>(unit->ifm());
 
       auto left = _binder->addOperand<int32_t>();
@@ -271,7 +272,6 @@ private:
   }
 
 private:
-  ANNBinder *const _binder;
   coco::Data *const _data;
 };
 
@@ -293,8 +293,8 @@ public:
   }
 
 public:
-  Compatibility kind(const coco::Instr *ins) const;
   Compatibility kind(const coco::Block *blk) const;
+  Compatibility kind(const Appender &appender) const;
 
 public:
   void build(void) const;
@@ -303,68 +303,9 @@ private:
   enco::Code *_code;
 };
 
-Compatibility ANNGroupBuilder::kind(const coco::Instr *ins) const
+Compatibility ANNGroupBuilder::kind(const Appender &app) const
 {
-  struct CompatibilityCheck final : public coco::Instr::DefaultVisitor<bool>,
-                                    public coco::Op::DefaultVisitor<bool>
-  {
-    //
-    // Instruction
-    //
-    bool visit(const coco::UnitF *unit) override
-    {
-      // TODO Check data layout
-      return unit->op()->accept(this);
-    }
-
-    bool visit(const coco::Shuffle *) override
-    {
-      // TODO Distinguish Reshape
-      return false;
-    }
-
-    //
-    // Op
-    //
-    bool visit(const coco::Conv2D *) override
-    {
-      // TODO Check data layout
-      return true;
-    }
-
-    bool visit(const coco::MaxPool2D *) override
-    {
-      // TODO Check data layout
-      return true;
-    }
-
-    bool visit(const coco::AvgPool2D *avgpool) override
-    {
-      if (avgpool->divisor() != coco::AvgPool2D::Divisor::PaddingExcluded)
-      {
-        // When ANN runtime computes the average of each receptive field,
-        // it uses the number of valid(=non-padding) elements as a divisor.
-        return false;
-      }
-
-      // TODO Check data layout
-      return true;
-    }
-
-    bool visit(const coco::ReLU *) override
-    {
-      // TODO Check data layout
-      return true;
-    }
-
-    bool visit(const coco::PadF *) override
-    {
-      // TODO Check data layout
-      return true;
-    }
-  };
-
-  return ins->accept(CompatibilityCheck{}) ? COMPATIBLE : INCOMPATIBLE;
+  return app ? COMPATIBLE : INCOMPATIBLE;
 }
 
 Compatibility ANNGroupBuilder::kind(const coco::Block *blk) const
@@ -375,6 +316,9 @@ Compatibility ANNGroupBuilder::kind(const coco::Block *blk) const
 void ANNGroupBuilder::build(void) const
 {
   auto m = _code->module();
+  auto d = _code->data();
+
+  ANNOpBuilder op_builder{d};
 
   // ANNGroupBuilder will construct a sequence of blocks from the original block sequence, and
   // a destination block (that dst_blk points to) is the tail of the generated sequence.
@@ -414,18 +358,28 @@ void ANNGroupBuilder::build(void) const
       ins = cur_ins->next();
       cur_ins->detach();
 
+      auto cur_append = cur_ins->accept(op_builder);
+
       // Create a new compatible block and use it as a destination block if the current
       // destination block is absent or incompatible with the instruction of intereset.
-      if ((dst_blk == nullptr) || (kind(cur_ins) != kind(dst_blk)))
+      if ((dst_blk == nullptr) || (kind(cur_append) != kind(dst_blk)))
       {
-        append(kind(cur_ins));
+        append(kind(cur_append));
       }
 
       assert(dst_blk != nullptr);
-      assert(kind(cur_ins) == kind(dst_blk));
+      assert(kind(cur_append) == kind(dst_blk));
 
       // Append ins to the dst_blk block
       dst_blk->instr()->append(cur_ins);
+
+      if (cur_append)
+      {
+        // Update Android NN IR if the current instruction is compatible
+        auto binder = _code->ann()->find(dst_blk);
+        assert(binder != nullptr);
+        cur_append(binder);
+      }
     }
 
     // Destroy the source block
@@ -508,13 +462,6 @@ void ANNModuleBuilder::build(void) const
     auto binder = _code->ann()->nth(n);
     auto block = binder->block();
 
-    ANNOpBuilder op_builder{binder, _code->data()};
-
-    for (auto ins = block->instr()->head(); ins; ins = ins->next())
-    {
-      ins->accept(op_builder);
-    }
-
     // Let's identify input/output bags
     binder->identifyInputs(inputs(binder->block()));
     binder->identifyOutputs(outputs(binder->block()));