From 2317062bbf20c751983f36fdc82e90316d77e82a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 9 Sep 2019 19:14:25 +0900 Subject: [PATCH] [moco-tf] Apply shape_inference_done (#7304) This will apply to use shape_inference_done() for fix_shape nodes Signed-off-by: SaeHie Park --- .../moco-tf/src/Transforms/FixShapeTransform.cpp | 61 +++++----------------- 1 file changed, 14 insertions(+), 47 deletions(-) diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index c7cab16..9364391 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -337,12 +337,8 @@ template void update_window_data(T *node) bool fix_shape(loco::Pull *node) { - auto shapedata = node->annot(); - if (shapedata != nullptr) - { - // shape inference is already done for Pull + if (shape_inference_done(node)) return false; - } // Pull itself has shape information, copy them auto shape_data = make_shape_inference_data(node); @@ -379,12 +375,9 @@ bool fix_shape(moco::tf::TFAvgPool *node) { LOGGER(l); - auto shapedata = node->annot(); - if (shapedata != nullptr) - { - // shape inference is already done for TFAvgPool + if (shape_inference_done(node)) return false; - } + auto value = node->value(); loco::NodeShape value_shape; if (!node_shape(value, value_shape)) @@ -513,9 +506,8 @@ bool fix_shape(moco::tf::TFConcatV2 *node) { LOGGER(l); - if (node->annot() != nullptr) + if (shape_inference_done(node)) { - // shape inference is already done for TFConcatV2 INFO(l) << "Fix shape TFConcatV2 already done"; return false; } @@ -638,12 +630,8 @@ bool fix_shape(moco::tf::TFConcatV2 *node) bool fix_shape(moco::tf::TFConst *node) { - auto shapedata = node->annot(); - if (shapedata != nullptr) - { - // shape inference is already done for TFConst + if (shape_inference_done(node)) return false; - } // TFConst itself has shape information, copy them auto shape_data = make_shape_inference_data(node); @@ -663,12 +651,9 @@ bool fix_shape(moco::tf::TFConv2D *node) { LOGGER(l); - auto shapedata = node->annot(); - if (shapedata != nullptr) - { - // shape inference is already done + if (shape_inference_done(node)) return false; - } + auto ifm = node->ifm(); loco::NodeShape ifm_shape; if (!node_shape(ifm, ifm_shape)) @@ -755,12 +740,9 @@ bool fix_shape(moco::tf::TFDepthwiseConv2dNative *node) { LOGGER(l); - auto shapedata = node->annot(); - if (shapedata != nullptr) - { - // shape inference is already done + if (shape_inference_done(node)) return false; - } + auto ifm = node->ifm(); loco::NodeShape ifm_shape; if (!node_shape(ifm, ifm_shape)) @@ -858,12 +840,9 @@ bool fix_shape(moco::tf::TFMaxPool *node) { LOGGER(l); - auto shapedata = node->annot(); - if (shapedata != nullptr) - { - // shape inference is already done for TFMaxPool + if (shape_inference_done(node)) return false; - } + auto value = node->value(); loco::NodeShape value_shape; if (!node_shape(value, value_shape)) @@ -982,12 +961,8 @@ bool fix_shape(moco::tf::TFRelu6 *node) bool fix_shape(moco::tf::TFReshape *node) { - auto shapedata = node->annot(); - if (shapedata != nullptr) - { - // shape inference is already done for TFReshape + if (shape_inference_done(node)) return false; - } // For now, we only consider Fixed Reshape, i.e. Reshape with determined // 'shape' input. So here we only support case when 'shape' input of @@ -1050,12 +1025,8 @@ bool fix_shape(moco::tf::TFRsqrt *node) bool fix_shape(moco::tf::TFShape *node) { - auto shapedata = node->annot(); - if (shapedata != nullptr) - { - // shape inference is already done for TFShape + if (shape_inference_done(node)) return false; - } auto input = node->input(); loco::NodeShape input_shape; @@ -1105,12 +1076,8 @@ bool fix_shape(moco::tf::TFSquaredDifference *node) bool fix_shape(moco::tf::TFSqueeze *node) { - auto shapedata = node->annot(); - if (shapedata != nullptr) - { - // shape inference is already done for TFSqueeze + if (shape_inference_done(node)) return false; - } auto input = node->input(); loco::NodeShape input_shape; -- 2.7.4