From: 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Tue, 16 Jul 2019 06:17:57 +0000 (+0900) Subject: [moco/tf] Use TFConst for Resolve FBN (#4291) X-Git-Tag: nncc_backup~50 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0fdbb820e595283a1f5e620ca161a316f04aee2b;p=platform%2Fcore%2Fml%2Fnnfw.git [moco/tf] Use TFConst for Resolve FBN (#4291) This will change ResolveFusedBatchNorm transformation to use TFConst instead of ConstGen as an inputs Signed-off-by: SaeHie Park --- diff --git a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp index 827a89f..37074af 100644 --- a/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp +++ b/contrib/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp @@ -17,6 +17,7 @@ #include "ResolveFusedBatchNorm.h" #include "IR/TFAdd.h" +#include "IR/TFConst.h" #include "IR/TFMul.h" #include "Convert.h" @@ -33,7 +34,7 @@ namespace { -bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc) +bool is_same_shape(moco::tf::TFConst *lc, moco::tf::TFConst *rc) { if (lc->rank() != rc->rank()) return false; @@ -46,7 +47,7 @@ bool is_same_shape(loco::ConstGen *lc, loco::ConstGen *rc) return true; } -void copy_shape(const loco::ConstGen *src, loco::ConstGen *dst) +void copy_shape(const moco::tf::TFConst *src, moco::tf::TFConst *dst) { assert(src != nullptr); assert(dst != nullptr); @@ -102,10 +103,10 @@ bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node) return false; } - auto tffbn_gamma = dynamic_cast(node->gamma()); - auto tffbn_beta = dynamic_cast(node->beta()); - auto tffbn_mean = dynamic_cast(node->mean()); - auto tffbn_variance = dynamic_cast(node->variance()); + auto tffbn_gamma = dynamic_cast(node->gamma()); + auto tffbn_beta = dynamic_cast(node->beta()); + auto tffbn_mean = dynamic_cast(node->mean()); + auto tffbn_variance = dynamic_cast(node->variance()); // all should be const if (tffbn_gamma == nullptr || tffbn_beta == nullptr || tffbn_mean == nullptr || @@ -189,7 +190,7 @@ bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node) * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param) * %21:fbn_add_param = ConstGen(fbn_add_param) */ - auto const_fbn_mul_0_param = graph->nodes()->create(); + auto const_fbn_mul_0_param = graph->nodes()->create(); const_fbn_mul_0_param->dtype(loco::DataType::FLOAT32); copy_shape(tffbn_gamma, const_fbn_mul_0_param); const_fbn_mul_0_param->size(const_num_elements); @@ -197,7 +198,7 @@ bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node) { const_fbn_mul_0_param->at(i) = fbn_mul_0_param.get()[i]; } - auto const_fbn_add_param = graph->nodes()->create(); + auto const_fbn_add_param = graph->nodes()->create(); const_fbn_add_param->dtype(loco::DataType::FLOAT32); copy_shape(tffbn_gamma, const_fbn_add_param); const_fbn_add_param->size(const_num_elements);