From: 박천교/On-Device Lab(SR)/Engineer/삼성전자 Date: Fri, 2 Aug 2019 08:11:07 +0000 (+0900) Subject: [moco-tf] TFSqueeze shape inference (#6117) X-Git-Tag: submit/tizen/20190809.050447~214 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1da0c6d699d5efba250fc2e42aec8c3af2072ff6;p=platform%2Fcore%2Fml%2Fnnfw.git [moco-tf] TFSqueeze shape inference (#6117) * [moco-tf] TFSqueeze shape inference This commit introduces shape inference for TFSqueeze Signed-off-by: Cheongyo Bahk * Review fix: use set, use lambda for assertion --- diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index 11e330d..9fa72ae 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -1148,8 +1148,96 @@ bool fix_shape(moco::tf::TFRsqrt *node) bool fix_shape(moco::tf::TFSqueeze *node) { - // TODO implement - throw std::runtime_error("NYI fix_shape TFSqueeze"); + auto shapedata = node->annot(); + if (shapedata != nullptr) + { + // shape inference is already done for TFSqueeze + return false; + } + + auto input = node->input(); + auto input_shape = input->annot(); + if (input_shape == nullptr) + { + // Input shape is required for TFSqueeze shape inference + return false; + } + + // TODO Not sure Squeeze only get input as Tensor + // Note that tensor_shape() has assertion in it + auto input_tensor_shape = input_shape->tensor_shape(); + + auto squeeze_dims_vec = node->squeeze_dims(); + const std::set squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend()); + + loco::TensorShape node_shape; + uint32_t node_rank = 0; + + if (squeeze_dims.empty()) + { + // Remove all dimensions whose value is 1 + for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis) + { + assert(input_tensor_shape.dim(axis).known()); + auto dim = input_tensor_shape.dim(axis).value(); + if (dim != 1) + { + assert(dim > 1); + node_shape.rank(++node_rank); + node_shape.dim(node_rank - 1) = dim; + } + } + } + else + { + uint32_t input_rank = input_tensor_shape.rank(); + + auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() { + if (squeeze_dims.size() >= input_rank) + return false; + for (auto squeeze_dim : squeeze_dims) + { + // Negative squeeze dimensions should be resolve before + if (squeeze_dim < 0) + return false; + if (squeeze_dim >= (int64_t)input_rank) + return false; + } + return true; + }; + + assert(is_valid_squeeze_dims()); + + // Remove squeeze dimensions only + for (uint32_t axis = 0; axis < input_rank; ++axis) + { + assert(input_tensor_shape.dim(axis).known()); + auto dim = input_tensor_shape.dim(axis).value(); + if (squeeze_dims.find((int64_t)axis) == squeeze_dims.cend()) + { + // Not squeeze dim + node_shape.rank(++node_rank); + node_shape.dim(node_rank - 1) = dim; + } + else + { + // Is squeeze dim + assert(dim == 1); + // DO NOTHING + } + } + } + + assert(node_shape.rank() > 0); + + auto shape_annot = stdex::make_unique(); + shape_annot->tensor_shape(node_shape); + node->annot(std::move(shape_annot)); + + LOGGER(l); + INFO(l) << "Fix TFSqueeze shape = " << node_shape; + + return true; } } // namespace