[moco-tf] Use precedingOp in FuseBinaryIntoPreceding (#6400)
author남궁석/On-Device Lab(SR)/Engineer/삼성전자 <sk.namkoong@samsung.com>
Fri, 9 Aug 2019 06:06:44 +0000 (15:06 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 9 Aug 2019 06:06:44 +0000 (15:06 +0900)
This commit will introduce precedingOp instead of each operation variable

Signed-off-by: Seok NamKoong <sk.namkoong@samsung.com>
compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.cpp

index 0ff26b2..642ba40 100644 (file)
@@ -296,37 +296,35 @@ bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFMul *node)
     return false;
 
   moco::tf::TFConst *mulparam = nullptr;
-  moco::tf::TFConv2D *conv2d = nullptr;
+  moco::tf::TFNode *precedingOp = nullptr;
   // TODO support DepthWiseConv2D
   // TODO support FullyConnected
 
   if (xc != nullptr)
   {
     mulparam = xc;
-    conv2d = dynamic_cast<moco::tf::TFConv2D *>(node->y());
+    precedingOp = dynamic_cast<moco::tf::TFNode *>(node->y());
   }
   else // yc != nullptr
   {
     mulparam = yc;
-    conv2d = dynamic_cast<moco::tf::TFConv2D *>(node->x());
+    precedingOp = dynamic_cast<moco::tf::TFNode *>(node->x());
   }
 
-  if (conv2d == nullptr)
-    return false;
-
   assert(mulparam->dtype() == loco::DataType::FLOAT32);
 
-  auto conv2d_fused =
-      fused_conv_node<FuseType::Conv2D, moco::tf::TFConv2D>(graph, mulparam, conv2d);
+  moco::tf::TFNode *fused_node = nullptr;
+  if (auto conv2d = dynamic_cast<moco::tf::TFConv2D *>(precedingOp))
+    fused_node = fused_conv_node<FuseType::Conv2D, moco::tf::TFConv2D>(graph, mulparam, conv2d);
 
   // Not ready yet
-  if (conv2d_fused == nullptr)
+  if (fused_node == nullptr)
     return false;
 
   // Replace TFMul node with new precedingOp with fused kernel
   // This will leave existing precedingOp as-is but can be removed if not used
   // from other transformations
-  replace(node).with(conv2d_fused);
+  replace(node).with(fused_node);
   // TODO check if need to disconnect
   // node->x(nullptr);
   // node->y(nullptr);
@@ -420,7 +418,7 @@ bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFAdd *node)
     return false;
 
   moco::tf::TFConst *addparam = nullptr;
-  moco::tf::TFConv2D *conv2d = nullptr;
+  moco::tf::TFNode *precedingOp = nullptr;
   moco::tf::TFBiasAdd *biasadd = nullptr;
   // TODO support DepthWiseConv2D
   // TODO support FullyConnected
@@ -428,14 +426,12 @@ bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFAdd *node)
   if (xc != nullptr)
   {
     addparam = xc;
-    conv2d = dynamic_cast<moco::tf::TFConv2D *>(node->y());
-    biasadd = dynamic_cast<moco::tf::TFBiasAdd *>(node->y());
+    precedingOp = dynamic_cast<moco::tf::TFNode *>(node->y());
   }
   else // yc != nullptr
   {
     addparam = yc;
-    conv2d = dynamic_cast<moco::tf::TFConv2D *>(node->x());
-    biasadd = dynamic_cast<moco::tf::TFBiasAdd *>(node->x());
+    precedingOp = dynamic_cast<moco::tf::TFNode *>(node->x());
   }
 
   auto addparam_shape_inf = addparam->annot<moco::tf::ShapeInferenceData>();
@@ -454,10 +450,14 @@ bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFAdd *node)
     return false;
   }
 
-  if (conv2d != nullptr)
+  if (auto conv2d = dynamic_cast<moco::tf::TFConv2D *>(precedingOp))
   {
     biasadd = create_biasadd_node<moco::tf::TFConv2D>(graph, addparam, conv2d);
   }
+  else if (auto old_bias_add = dynamic_cast<moco::tf::TFBiasAdd *>(precedingOp))
+  {
+    biasadd = old_bias_add;
+  }
 
   if (biasadd == nullptr)
   {