[exo] Refactoring: Extract code from big method (Fuse) (#8179)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 16 Oct 2019 08:03:47 +0000 (17:03 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 16 Oct 2019 08:03:47 +0000 (17:03 +0900)
A piece of code is extracted as a method (`create_fused_bias()`) from big method (`Fuse`).

For diff difficulty, old code will be replace with one single call in next PR.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo/src/Pass/FuseConv2DAddSubPass.cpp

index 75b15d0..26be34a 100644 (file)
@@ -135,13 +135,45 @@ private:
   TFLType *_fusable_node;
   locoex::TFLConst *_const_node;
   locoex::TFLConv2D *_conv2d_node;
+
+  locoex::TFLConst *create_fused_bias();
 };
 
+template <typename TFLType> locoex::TFLConst *Fuser<TFLType>::create_fused_bias()
+{
+  // we have to create a new bias by adding/substracting bias and const node (of TFLAdd or TFLSub)
+  auto bias = dynamic_cast<locoex::TFLConst *>(_conv2d_node->bias());
+  assert(bias->dtype() == loco::DataType::FLOAT32 &&
+         _const_node->dtype() == loco::DataType::FLOAT32);
+
+  assert(bias->rank() == 1 && _const_node->rank() == 1);
+  assert(bias->dim(0) == _const_node->dim(0));
+
+  // build a new bias
+  auto new_bias = _graph->nodes()->create<locoex::TFLConst>();
+  {
+    new_bias->dtype(loco::DataType::FLOAT32);
+
+    new_bias->rank(1);
+    new_bias->dim(0) = bias->dim(0);
+
+    new_bias->size<loco::DataType::FLOAT32>(bias->dim(0).value());
+
+    for (uint32_t x = 0; x < bias->dim(0).value(); x++)
+      new_bias->at<loco::DataType::FLOAT32>(x) = calc<TFLType>(
+          bias->at<loco::DataType::FLOAT32>(x), _const_node->at<loco::DataType::FLOAT32>(x));
+  }
+
+  return new_bias;
+}
+
 /**
  * @brief fuse TFLAdd or TFLSub into conv2d
  */
 template <typename TFLType> void Fuser<TFLType>::fuse(void)
 {
+  // TODO replace code below with `auto new_bias = create_fused_bias();`
+
   // conv2d bias could be either all 0 values or set with other values.
   // we have to create a new bias by adding/substracting const_node and bias
   auto bias = dynamic_cast<locoex::TFLConst *>(_conv2d_node->bias());
@@ -166,6 +198,7 @@ template <typename TFLType> void Fuser<TFLType>::fuse(void)
       new_bias->at<loco::DataType::FLOAT32>(x) = calc<TFLType>(
           bias->at<loco::DataType::FLOAT32>(x), _const_node->at<loco::DataType::FLOAT32>(x));
   }
+  // TODO replace the above
 
   // replace node with new_bias
   // note that loco::replace() is not used because bias could be input of other op just in case