From b8feb7088533e164d6d84cb970f05fb0c6aef34f Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 13 Aug 2019 13:52:01 +0900 Subject: [PATCH] [moco-tf] Clear ConcatV2 fix shape (#6523) This will erase ConcatV2 fix shape to reduce complexity of diff for migration to using VariadicArityNode Signed-off-by: SaeHie Park --- .../moco-tf/src/Transforms/FixShapeTransform.cpp | 61 ++-------------------- 1 file changed, 3 insertions(+), 58 deletions(-) diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index d3f45b2..fc89e3d 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -790,66 +790,11 @@ bool fix_shape(moco::tf::TFBiasAdd *node) bool fix_shape(moco::tf::TFConcatV2 *node) { - auto concat_data = node->annot(); - if (concat_data == nullptr) - { - // shape inference is already done for TFConcatV2 - assert(node->annot() != nullptr); - return false; - } - assert(node->annot() == nullptr); + (void)node; - auto lhs = node->lhs(); - auto rhs = node->rhs(); - auto lhs_shapedata = lhs->annot(); - auto rhs_shapedata = rhs->annot(); - if (lhs_shapedata == nullptr || rhs_shapedata == nullptr) - { - // postpone as previous input node(s) hasn't been processed. - // this will return false as there was nothing changed, but the input of this - // node should be changed and from that this method should be called again. - // if not, the network may have some problem and the final output may not have - // the right shape value and we can identify with some validation at final stage. - return false; - } - - uint32_t lhs_rank = lhs_shapedata->rank(); - uint32_t rhs_rank = rhs_shapedata->rank(); - assert(lhs_rank == rhs_rank); + throw std::runtime_error("NYI fix_shape TFConcatV2"); - int32_t axis_tf = concat_data->axis(); - if (axis_tf < 0) - { - axis_tf = static_cast(lhs_rank) + axis_tf; - } - assert(0 <= axis_tf && axis_tf < static_cast(lhs_rank)); - // clear annotation ConcatData - node->annot(nullptr); - - uint32_t axis_loco = static_cast(axis_tf); - node->axis(axis_loco); - - // Set ShapeInferenceData for TensorConcat - auto shape_data = stdex::make_unique(); - shape_data->rank(lhs_rank); - for (uint32_t index = 0; index < lhs_rank; ++index) - { - uint32_t lhs_dim = lhs_shapedata->dim(index).value(); - uint32_t rhs_dim = rhs_shapedata->dim(index).value(); - // "lhs_dim == rhs_dim" should hold when "index != axis_loco" - // or doesn't care when "index == axis_loco" - assert(index == axis_loco || lhs_dim == rhs_dim); - - uint32_t new_dim = (index == axis_loco) ? lhs_dim + rhs_dim : lhs_dim; - - if (lhs_shapedata->dim(index).known()) - shape_data->dim(index) = new_dim; - else - shape_data->dim(index).unset(); - } - node->annot(std::move(shape_data)); - - return true; + return false; } bool fix_shape(moco::tf::TFConst *node) -- 2.7.4