From 5df9f0559f42f124bb0ba3eec6ff506d0f8ea90d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 5 Jun 2019 13:32:56 +0900 Subject: [PATCH] [moco/tf] Implement fix_shape for ConstGen (#3692) This will implement fix_shape in shape inference for ConstGen node Signed-off-by: SaeHie Park --- .../tf/src/Transforms/FixShapeTransform.cpp | 23 ++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp b/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp index 62bd5cc..2e83ac7 100644 --- a/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp +++ b/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp @@ -72,8 +72,27 @@ bool copy_shapedata(const loco::Node *src, loco::Node *dst) bool fix_shape(loco::ConstGen *node) { - // TODO fill this - return false; + const ShapeInferenceData *shapedata = node->annot(); + if (shapedata != nullptr) + { + // shape inference is already done for ConstGen + return false; + } + + // ConstGen itself has shape information, copy them + auto shape_data = stdex::make_unique(); + uint32_t rank = node->rank(); + shape_data->rank(rank); + for (uint32_t index = 0; index < rank; ++index) + { + if (node->dim(index).known()) + shape_data->dim(index) = loco::make_dimension(node->dim(index).value()); + else + shape_data->dim(index).unset(); + } + node->annot(std::move(shape_data)); + + return true; } bool fix_shape(loco::FeatureDecode *node) -- 2.7.4