From 0fdbb820e595283a1f5e620ca161a316f04aee2b Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 16 Jul 2019 15:17:57 +0900 Subject: [PATCH] [moco/tf] Use TFConst for Resolve FBN (#4291) This will change ResolveFusedBatchNorm transformation to use TFConst instead of ConstGen as an inputs Signed-off-by: SaeHie Park --- .../moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp index 827a89f..37074af 100644 --- a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp +++ b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp @@ -17,6 +17,7 @@ #include "ResolveFusedBatchNorm.h" #include "IR/TFAdd.h" +#include "IR/TFConst.h" #include "IR/TFMul.h" #include "Convert.h" @@ -33,7 +34,7 @@ namespace { -bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc) +bool is_same_shape(moco::tf::TFConst *lc, moco::tf::TFConst *rc) { if (lc->rank() != rc->rank()) return false; @@ -46,7 +47,7 @@ bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc) return true; } -void copy_shape(const loco::ConstGen *src, loco::ConstGen *dst) +void copy_shape(const moco::tf::TFConst *src, moco::tf::TFConst *dst) { assert(src != nullptr); assert(dst != nullptr); @@ -102,10 +103,10 @@ bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node) return false; } - auto tffbn_gamma = dynamic_cast(node->gamma()); - auto tffbn_beta = dynamic_cast(node->beta()); - auto tffbn_mean = dynamic_cast(node->mean()); - auto tffbn_variance = dynamic_cast(node->variance()); + auto tffbn_gamma = dynamic_cast(node->gamma()); + auto tffbn_beta = dynamic_cast(node->beta()); + auto tffbn_mean = dynamic_cast(node->mean()); + auto tffbn_variance = dynamic_cast(node->variance()); // all should be const if (tffbn_gamma == nullptr || tffbn_beta == nullptr || tffbn_mean == nullptr || @@ -189,7 +190,7 @@ bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node) * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param) * %21:fbn_add_param = ConstGen(fbn_add_param) */ - auto const_fbn_mul_0_param = graph->nodes()->create(); + auto const_fbn_mul_0_param = graph->nodes()->create(); const_fbn_mul_0_param->dtype(loco::DataType::FLOAT32); copy_shape(tffbn_gamma, const_fbn_mul_0_param); const_fbn_mul_0_param->size(const_num_elements); @@ -197,7 +198,7 @@ bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node) { const_fbn_mul_0_param->at(i) = fbn_mul_0_param.get()[i]; } - auto const_fbn_add_param = graph->nodes()->create(); + auto const_fbn_add_param = graph->nodes()->create(); const_fbn_add_param->dtype(loco::DataType::FLOAT32); copy_shape(tffbn_gamma, const_fbn_add_param); const_fbn_add_param->size(const_num_elements); -- 2.7.4