From 779409f560745b84d6bad7e4e443d30093c4d053 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=82=A8=EA=B6=81=EC=84=9D/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 9 Aug 2019 15:06:44 +0900 Subject: [PATCH] [moco-tf] Use precedingOp in FuseBinaryIntoPreceding (#6400) This commit will introduce precedingOp instead of each operation variable Signed-off-by: Seok NamKoong --- .../src/Transforms/FuseBinaryIntoPreceding.cpp | 32 +++++++++++----------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.cpp b/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.cpp index 0ff26b2..642ba40 100644 --- a/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.cpp +++ b/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.cpp @@ -296,37 +296,35 @@ bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFMul *node) return false; moco::tf::TFConst *mulparam = nullptr; - moco::tf::TFConv2D *conv2d = nullptr; + moco::tf::TFNode *precedingOp = nullptr; // TODO support DepthWiseConv2D // TODO support FullyConnected if (xc != nullptr) { mulparam = xc; - conv2d = dynamic_cast(node->y()); + precedingOp = dynamic_cast(node->y()); } else // yc != nullptr { mulparam = yc; - conv2d = dynamic_cast(node->x()); + precedingOp = dynamic_cast(node->x()); } - if (conv2d == nullptr) - return false; - assert(mulparam->dtype() == loco::DataType::FLOAT32); - auto conv2d_fused = - fused_conv_node(graph, mulparam, conv2d); + moco::tf::TFNode *fused_node = nullptr; + if (auto conv2d = dynamic_cast(precedingOp)) + fused_node = fused_conv_node(graph, mulparam, conv2d); // Not ready yet - if (conv2d_fused == nullptr) + if (fused_node == nullptr) return false; // Replace TFMul node with new precedingOp with fused kernel // This will leave existing precedingOp as-is but can be removed if not used // from other transformations - replace(node).with(conv2d_fused); + replace(node).with(fused_node); // TODO check if need to disconnect // node->x(nullptr); // node->y(nullptr); @@ -420,7 +418,7 @@ bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFAdd *node) return false; moco::tf::TFConst *addparam = nullptr; - moco::tf::TFConv2D *conv2d = nullptr; + moco::tf::TFNode *precedingOp = nullptr; moco::tf::TFBiasAdd *biasadd = nullptr; // TODO support DepthWiseConv2D // TODO support FullyConnected @@ -428,14 +426,12 @@ bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFAdd *node) if (xc != nullptr) { addparam = xc; - conv2d = dynamic_cast(node->y()); - biasadd = dynamic_cast(node->y()); + precedingOp = dynamic_cast(node->y()); } else // yc != nullptr { addparam = yc; - conv2d = dynamic_cast(node->x()); - biasadd = dynamic_cast(node->x()); + precedingOp = dynamic_cast(node->x()); } auto addparam_shape_inf = addparam->annot(); @@ -454,10 +450,14 @@ bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFAdd *node) return false; } - if (conv2d != nullptr) + if (auto conv2d = dynamic_cast(precedingOp)) { biasadd = create_biasadd_node(graph, addparam, conv2d); } + else if (auto old_bias_add = dynamic_cast(precedingOp)) + { + biasadd = old_bias_add; + } if (biasadd == nullptr) { -- 2.7.4