From: 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Wed, 5 Jun 2019 04:32:56 +0000 (+0900) Subject: [moco/tf] Implement fix_shape for ConstGen (#3692) X-Git-Tag: nncc_backup~454 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5df9f0559f42f124bb0ba3eec6ff506d0f8ea90d;p=platform%2Fcore%2Fml%2Fnnfw.git [moco/tf] Implement fix_shape for ConstGen (#3692) This will implement fix_shape in shape inference for ConstGen node Signed-off-by: SaeHie Park --- diff --git a/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp b/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp index 62bd5cc..2e83ac7 100644 --- a/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp +++ b/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp @@ -72,8 +72,27 @@ bool copy_shapedata(const loco::Node *src, loco::Node *dst) bool fix_shape(loco::ConstGen *node) { - // TODO fill this - return false; + const ShapeInferenceData *shapedata = node->annot(); + if (shapedata != nullptr) + { + // shape inference is already done for ConstGen + return false; + } + + // ConstGen itself has shape information, copy them + auto shape_data = stdex::make_unique(); + uint32_t rank = node->rank(); + shape_data->rank(rank); + for (uint32_t index = 0; index < rank; ++index) + { + if (node->dim(index).known()) + shape_data->dim(index) = loco::make_dimension(node->dim(index).value()); + else + shape_data->dim(index).unset(); + } + node->annot(std::move(shape_data)); + + return true; } bool fix_shape(loco::FeatureDecode *node)