[exo] Revising setCandidate() of FuseBiasAddPass (#8538)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Mon, 28 Oct 2019 09:18:08 +0000 (18:18 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 28 Oct 2019 09:18:08 +0000 (18:18 +0900)
setCandidate() of FuseBiasAddPass was revised to support FormerT and LatterT, and removed names related to conv2d.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo/src/Pass/FuseBiasAddPass.cpp

index b5983f1..ab4bc61 100644 (file)
@@ -72,17 +72,6 @@ FormerT *get_former(loco::Node *x, loco::Node *y)
   return nullptr;
 }
 
-// TODO replace this with get_former
-locoex::TFLConv2D *get_conv2d(loco::Node *x, loco::Node *y)
-{
-  if (auto conv2d_node = dynamic_cast<locoex::TFLConv2D *>(x))
-    return conv2d_node;
-  else if (auto conv2d_node = dynamic_cast<locoex::TFLConv2D *>(y))
-    return conv2d_node;
-
-  return nullptr;
-}
-
 /// @brief Finds input that is TFLConst and set it to new_input
 void set_const_input(locoex::TFLNode *node, locoex::TFLConst *new_input)
 {
@@ -204,7 +193,6 @@ template <class LatterT> locoex::TFLConst *Fuser<LatterT>::create_fused_bias_con
 }
 
 // FuseBiasAddPass works when former->fusedActivationFunction() == NONE
-// Note: This method is used for assert(...)
 bool check_act_func(FormerT *former)
 {
   using FusedActFuncMixin = locoex::TFLNodeMixin<locoex::TFLNodeTrait::FusedActFunc>;
@@ -259,7 +247,6 @@ template <class LatterT> void Fuser<LatterT>::fuse(void)
   _latter->y(nullptr);
 }
 
-// TODO rewrite this by using FormerT and LatterT (Remove Conv2D dependency)
 struct Collector final : public locoex::TFLNodeMutableVisitor<void>
 {
   template <class LatterT>
@@ -269,31 +256,26 @@ struct Collector final : public locoex::TFLNodeMutableVisitor<void>
                       std::is_same<LatterT, locoex::TFLSub>::value,
                   "wrong template type");
 
-    // TODO Consider TFLDepthwiseConv2D
-    locoex::TFLConv2D *conv2d_node = dynamic_cast<locoex::TFLConv2D *>(former);
-
-    if (!(const_node && conv2d_node))
-      return;
-
-    if (conv2d_node->fusedActivationFunction() != locoex::FusedActFunc::NONE)
+    if (!check_act_func(former))
       return;
 
-    auto conv2d_depth = loco::shape_get(conv2d_node).as<loco::TensorShape>().dim(3).value();
-    auto const_shape = loco::shape_get(const_node).as<loco::TensorShape>();
+    auto depth =
+        loco::shape_get(as_loco_node(former)).template as<loco::TensorShape>().dim(3).value();
+    auto const_shape = loco::shape_get(const_node).template as<loco::TensorShape>();
 
-    if (const_shape.rank() == 1 and const_shape.dim(0) == conv2d_depth)
+    if (const_shape.rank() == 1 and const_shape.dim(0) == depth)
     {
       candidates.insert(latter);
     }
-    // when Const has only one value, create a new const with shape [depth_of_conv]
+    // when Const has only one value, create a new const with shape [depth]
     else if (const_shape.rank() == 0 or (const_shape.rank() == 1 and const_shape.dim(0) == 1))
     {
-      if (!(loco::dtype_get(conv2d_node) == loco::DataType::FLOAT32))
-        EXO_THROW("unsupported TFLConv2D data type");
+      if (!(loco::dtype_get(as_loco_node(former)) == loco::DataType::FLOAT32))
+        EXO_THROW("unsupported data type");
       if (!(const_node->dtype() == loco::DataType::FLOAT32))
-        EXO_THROW("unsupported TFLConst data type");
+        EXO_THROW("unsupported data type");
 
-      auto new_bias_node = create_widened(const_node, conv2d_depth);
+      auto new_bias_node = create_widened(const_node, depth);
 
       // Replacing TFLConst input of TFLAdd or TFLSub.
       // Note that calling loco::replace(const_node).with(new_bias_node) could be dangerous