[enco] Support constant folding for add, sub, mul, div, sqrt (#2821)
author남궁석/On-Device Lab(SR)/Engineer/삼성전자 <sk.namkoong@samsung.com>
Mon, 14 Jan 2019 02:18:18 +0000 (11:18 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 14 Jan 2019 02:18:18 +0000 (11:18 +0900)
* [enco] Support constant folding for add, sub, mul, div, sqrt

Until now, only `coco::Copy` instruction was available for constant folding.
This commit will support some `coco::Eval` instructions in constant folding,
which is add, sub, mul, div, sqrt.

Signed-off-by: Seok NamKoong <sk.namkoong@samsung.com>
* minor modifications

contrib/enco/core/src/Transforms/ConstantFolding.cpp

index 2c57b8f..61374d2 100644 (file)
@@ -18,7 +18,7 @@
 #include "Session.h"
 
 #include <queue>
-
+#include <cmath>
 #include <cassert>
 
 namespace
@@ -168,6 +168,212 @@ void fold_constant(std::queue<coco::Bag *> &q, coco::Copy *copy)
   q.push(dst_bag);
 }
 
+template <typename Callable>
+void fold_constant_op(std::queue<coco::Bag *> &q, coco::UnaryOp *op, Callable evaluate)
+{
+  auto m = op->module();
+  auto d = enco::data(m);
+
+  auto ins = op->parent();
+  auto eval = ins->asEval();
+
+  // UnaryOp has only one arg
+  auto src_obj = *(op->uses().begin());
+  auto src_bag = src_obj->bag();
+
+  auto dst_obj = eval->out();
+  auto dst_bag = dst_obj->bag();
+
+  assert(d->allocated(src_bag));
+  assert(!d->allocated(dst_bag));
+
+  // TODO Support other data type
+  auto src_span = d->f32()->weight(src_bag);
+  assert(src_span.data() != nullptr);
+
+  auto src_feature = src_obj->asFeature();
+  auto dst_feature = dst_obj->asFeature();
+
+  // TODO Support other object type
+  if (src_feature == nullptr || dst_feature == nullptr)
+  {
+    return;
+  }
+
+  assert(src_feature != nullptr);
+  assert(dst_feature != nullptr);
+
+  // Allocate weight for destination
+  d->f32()->allocate(dst_bag);
+  auto dst_span = d->f32()->weight(dst_bag);
+
+  assert(src_feature->layout()->batch() == dst_feature->layout()->batch());
+  assert(src_feature->layout()->depth() == dst_feature->layout()->depth());
+  assert(src_feature->layout()->height() == dst_feature->layout()->height());
+  assert(src_feature->layout()->width() == dst_feature->layout()->width());
+
+  uint32_t const B = src_feature->layout()->batch();
+  uint32_t const C = src_feature->layout()->depth();
+  uint32_t const H = src_feature->layout()->height();
+  uint32_t const W = src_feature->layout()->width();
+
+  for (uint32_t b = 0; b < B; ++b)
+  {
+    for (uint32_t ch = 0; ch < C; ++ch)
+    {
+      for (uint32_t row = 0; row < H; ++row)
+      {
+        for (uint32_t col = 0; col < W; ++col)
+        {
+          auto src_ind = src_feature->layout()->at(b, ch, row, col);
+          auto dst_ind = dst_feature->layout()->at(b, ch, row, col);
+
+          evaluate(&dst_span[dst_ind.value()], src_span[src_ind.value()]);
+        }
+      }
+    }
+  }
+
+  // Let's detach eval
+  eval->out(nullptr);
+  eval->detach();
+
+  // Let's visit destination bag!
+  q.push(dst_bag);
+}
+
+template <typename Callable>
+void fold_constant_op(std::queue<coco::Bag *> &q, coco::BinaryOp *op, Callable evaluate)
+{
+  auto m = op->module();
+  auto d = enco::data(m);
+
+  auto ins = op->parent();
+  auto eval = ins->asEval();
+
+  // Already folded by the other bag
+  if (!eval->out())
+  {
+    return;
+  }
+
+  auto lhs_load = op->left()->asLoad();
+  auto lhs_obj = lhs_load->object();
+  auto lhs_bag = lhs_obj->bag();
+
+  auto rhs_load = op->right()->asLoad();
+  auto rhs_obj = rhs_load->object();
+  auto rhs_bag = rhs_obj->bag();
+
+  auto dst_obj = eval->out();
+  auto dst_bag = dst_obj->bag();
+
+  // The other bag is non-constant
+  if (!d->allocated(lhs_bag) || !d->allocated(rhs_bag))
+  {
+    return;
+  }
+
+  assert(d->allocated(lhs_bag));
+  assert(d->allocated(rhs_bag));
+  assert(!d->allocated(dst_bag));
+
+  // TODO Support other data type
+  auto lhs_span = d->f32()->weight(lhs_bag);
+  auto rhs_span = d->f32()->weight(rhs_bag);
+  assert(lhs_span.data() != nullptr);
+  assert(rhs_span.data() != nullptr);
+
+  auto lhs_feature = lhs_obj->asFeature();
+  auto rhs_feature = rhs_obj->asFeature();
+  auto dst_feature = dst_obj->asFeature();
+
+  // TODO Support other object type
+  if (lhs_feature == nullptr || rhs_feature == nullptr || dst_feature == nullptr)
+  {
+    return;
+  }
+
+  assert(lhs_feature != nullptr);
+  assert(rhs_feature != nullptr);
+  assert(dst_feature != nullptr);
+
+  // Allocate weight for destination
+  d->f32()->allocate(dst_bag);
+  auto dst_span = d->f32()->weight(dst_bag);
+
+  assert(lhs_feature->layout()->batch() == rhs_feature->layout()->batch());
+  assert(lhs_feature->layout()->depth() == rhs_feature->layout()->depth());
+  assert(lhs_feature->layout()->height() == rhs_feature->layout()->height());
+  assert(lhs_feature->layout()->width() == rhs_feature->layout()->width());
+
+  assert(lhs_feature->layout()->batch() == dst_feature->layout()->batch());
+  assert(lhs_feature->layout()->depth() == dst_feature->layout()->depth());
+  assert(lhs_feature->layout()->height() == dst_feature->layout()->height());
+  assert(lhs_feature->layout()->width() == dst_feature->layout()->width());
+
+  uint32_t const B = lhs_feature->layout()->batch();
+  uint32_t const C = lhs_feature->layout()->depth();
+  uint32_t const H = lhs_feature->layout()->height();
+  uint32_t const W = lhs_feature->layout()->width();
+
+  for (uint32_t b = 0; b < B; ++b)
+  {
+    for (uint32_t ch = 0; ch < C; ++ch)
+    {
+      for (uint32_t row = 0; row < H; ++row)
+      {
+        for (uint32_t col = 0; col < W; ++col)
+        {
+          auto lhs_ind = lhs_feature->layout()->at(b, ch, row, col);
+          auto rhs_ind = rhs_feature->layout()->at(b, ch, row, col);
+          auto dst_ind = dst_feature->layout()->at(b, ch, row, col);
+
+          evaluate(&dst_span[dst_ind.value()], lhs_span[lhs_ind.value()],
+                   rhs_span[rhs_ind.value()]);
+        }
+      }
+    }
+  }
+
+  // Let's detach eval
+  eval->out(nullptr);
+  eval->detach();
+
+  // Let's visit destination bag!
+  q.push(dst_bag);
+}
+
+void fold_constant(std::queue<coco::Bag *> &q, coco::Eval *eval)
+{
+  // TODO Support other data types
+  if (auto op = eval->op()->asSqrt())
+  {
+    fold_constant_op(q, op, [](float *dst, float value) { *dst = std::sqrt(value); });
+  }
+  else if (auto op = eval->op()->asAdd())
+  {
+    fold_constant_op(q, op, [](float *dst, float lhs, float rhs) { *dst = lhs + rhs; });
+  }
+  else if (auto op = eval->op()->asSub())
+  {
+    fold_constant_op(q, op, [](float *dst, float lhs, float rhs) { *dst = lhs - rhs; });
+  }
+  else if (auto op = eval->op()->asMul())
+  {
+    fold_constant_op(q, op, [](float *dst, float lhs, float rhs) { *dst = lhs * rhs; });
+  }
+  else if (auto op = eval->op()->asDiv())
+  {
+    fold_constant_op(q, op, [](float *dst, float lhs, float rhs) { *dst = lhs / rhs; });
+  }
+  else
+  {
+    // Not supported opteration, do nothing
+    // TODO Support other operations
+  }
+}
+
 void fold_constant(std::queue<coco::Bag *> &q, coco::Instr *ins)
 {
   if (auto copy = coco::safe_cast<coco::Copy>(ins))
@@ -175,6 +381,12 @@ void fold_constant(std::queue<coco::Bag *> &q, coco::Instr *ins)
     fold_constant(q, copy);
     return;
   }
+  if (auto eval = coco::safe_cast<coco::Eval>(ins))
+  {
+    fold_constant(q, eval);
+    return;
+  }
+
   // TODO Add more cases for constant folding
 }