[moco/tf] Use TFConst for Resolve FBN (#4291)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 16 Jul 2019 06:17:57 +0000 (15:17 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 16 Jul 2019 06:17:57 +0000 (15:17 +0900)
This will change ResolveFusedBatchNorm transformation to use TFConst instead of ConstGen as an inputs

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp

index 827a89f..37074af 100644 (file)
@@ -17,6 +17,7 @@
 #include "ResolveFusedBatchNorm.h"
 
 #include "IR/TFAdd.h"
+#include "IR/TFConst.h"
 #include "IR/TFMul.h"
 
 #include "Convert.h"
@@ -33,7 +34,7 @@
 namespace
 {
 
-bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc)
+bool is_same_shape(moco::tf::TFConst *lc, moco::tf::TFConst *rc)
 {
   if (lc->rank() != rc->rank())
     return false;
@@ -46,7 +47,7 @@ bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc)
   return true;
 }
 
-void copy_shape(const loco::ConstGen *src, loco::ConstGen *dst)
+void copy_shape(const moco::tf::TFConst *src, moco::tf::TFConst *dst)
 {
   assert(src != nullptr);
   assert(dst != nullptr);
@@ -102,10 +103,10 @@ bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node)
     return false;
   }
 
-  auto tffbn_gamma = dynamic_cast<loco::ConstGen *>(node->gamma());
-  auto tffbn_beta = dynamic_cast<loco::ConstGen *>(node->beta());
-  auto tffbn_mean = dynamic_cast<loco::ConstGen *>(node->mean());
-  auto tffbn_variance = dynamic_cast<loco::ConstGen *>(node->variance());
+  auto tffbn_gamma = dynamic_cast<moco::tf::TFConst *>(node->gamma());
+  auto tffbn_beta = dynamic_cast<moco::tf::TFConst *>(node->beta());
+  auto tffbn_mean = dynamic_cast<moco::tf::TFConst *>(node->mean());
+  auto tffbn_variance = dynamic_cast<moco::tf::TFConst *>(node->variance());
 
   // all should be const
   if (tffbn_gamma == nullptr || tffbn_beta == nullptr || tffbn_mean == nullptr ||
@@ -189,7 +190,7 @@ bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node)
    * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param)
    * %21:fbn_add_param = ConstGen(fbn_add_param)
    */
-  auto const_fbn_mul_0_param = graph->nodes()->create<loco::ConstGen>();
+  auto const_fbn_mul_0_param = graph->nodes()->create<moco::tf::TFConst>();
   const_fbn_mul_0_param->dtype(loco::DataType::FLOAT32);
   copy_shape(tffbn_gamma, const_fbn_mul_0_param);
   const_fbn_mul_0_param->size<loco::DataType::FLOAT32>(const_num_elements);
@@ -197,7 +198,7 @@ bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node)
   {
     const_fbn_mul_0_param->at<loco::DataType::FLOAT32>(i) = fbn_mul_0_param.get()[i];
   }
-  auto const_fbn_add_param = graph->nodes()->create<loco::ConstGen>();
+  auto const_fbn_add_param = graph->nodes()->create<moco::tf::TFConst>();
   const_fbn_add_param->dtype(loco::DataType::FLOAT32);
   copy_shape(tffbn_gamma, const_fbn_add_param);
   const_fbn_add_param->size<loco::DataType::FLOAT32>(const_num_elements);