From: 박천교/On-Device Lab(SR)/Engineer/삼성전자 Date: Sun, 28 Jul 2019 23:52:34 +0000 (+0900) Subject: [moco-tf] Implement shape inference for TFReshape (#5914) X-Git-Tag: submit/tizen/20190809.050447~375 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3ca2b41d4f72e76a0e8df7a1cae0018189235fcf;p=platform%2Fcore%2Fml%2Fnnfw.git [moco-tf] Implement shape inference for TFReshape (#5914) This commit implements shape inference for TFReshape on fix_shape(). Signed-off-by: Cheongyo Bahk --- diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index 87a2c9e..37cec53 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -745,8 +745,56 @@ bool fix_shape(moco::tf::TFMul *node) bool fix_shape(moco::tf::TFReshape *node) { - // TODO implement - throw std::runtime_error("NYI fix_shape TFReshape"); + auto shapedata = node->annot(); + if (shapedata != nullptr) + { + // shape inference is already done for TFReshape + return false; + } + + // For now, we only consider Fixed Reshape, i.e. Reshape with determined + // 'shape' input. So here we only support case when 'shape' input of + // TFReshape is TFConst. If 'shape' input is not TFConst, another + // transform (e.g. constant folding) should be done beforehand to make + // it TFConst. + // TODO Support dynamic Reshape + // Note that 'shape()' here is 'shape' input, not node's shape information + auto const_shape_input = dynamic_cast(node->shape()); + if (!const_shape_input) + { + // 'shape' input of TFReshape is not TFConst, try next time when it becomes TFConst + return false; + } + + // 'Shape' input should be integer tensor of rank 1, e.g. [2, 3, 4] or [3, -1] + assert(const_shape_input->dtype() == loco::DataType::S32); + assert(const_shape_input->rank() == 1); + + auto shape_rank = const_shape_input->dim(0).value(); + assert(shape_rank > 0); + + loco::TensorShape shape_data; + shape_data.rank(shape_rank); + for (uint32_t axis = 0; axis < shape_rank; ++axis) + { + shape_data.dim(axis) = const_shape_input->at(axis); + } + + // TODO Compare 'tensor' input and validate coherency? + // Not sure this is appropriate stage for this task. + + auto shape_annot = stdex::make_unique(); + shape_annot->tensor_shape(shape_data); + node->annot(std::move(shape_annot)); + + { + LOGGER(l); + auto shapedata = node->annot(); + assert(shapedata != nullptr); + INFO(l) << "Fix TFReshape shape = " << shapedata->tensor_shape(); + } + + return true; } bool fix_shape(moco::tf::TFRsqrt *node)