From 56c6d9e22ebce9a905d60292d4af53a34e20126d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 31 Jul 2019 08:39:48 +0900 Subject: [PATCH] [moco-tf] Fix shape for TFMaxPool (#6025) This will implement FixShape for TFMaxPool node Signed-off-by: SaeHie Park --- .../moco-tf/src/Transforms/FixShapeTransform.cpp | 105 ++++++++++++++++++++- 1 file changed, 103 insertions(+), 2 deletions(-) diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index 72cb3e4..17b821b 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -936,8 +936,109 @@ bool fix_shape(moco::tf::TFIdentity *node) bool fix_shape(moco::tf::TFMaxPool *node) { - // TODO implement - throw std::runtime_error("NYI fix_shape TFMaxPool"); + LOGGER(l); + + auto shapedata = node->annot(); + if (shapedata != nullptr) + { + // shape inference is already done for TFMaxPool + return false; + } + auto value = node->value(); + auto value_shapedata = value->annot(); + if (value_shapedata == nullptr) + { + // input node shape inference is not ready + return false; + } + + auto padding = node->padding(); + assert(padding == "VALID" || padding == "SAME"); + + // TODO move this to some new Transformation... + { + { + auto stride_data = node->annot(); + assert(stride_data == nullptr); + } + auto stride_data = stdex::make_unique(); + auto strides = node->strides(); + auto data_layout = plier::tf::as_data_layout(node->data_layout()); + if (data_layout == plier::tf::DataLayout::NHWC) + { + stride_data->stride()->vertical(strides[1]); + stride_data->stride()->horizontal(strides[2]); + } + else if (data_layout == plier::tf::DataLayout::NCHW) + { + stride_data->stride()->vertical(strides[2]); + stride_data->stride()->horizontal(strides[3]); + } + node->annot(std::move(stride_data)); + + { + auto window_data = node->annot(); + assert(window_data == nullptr); + } + auto window_data = stdex::make_unique(); + auto ksize = node->ksize(); + if (data_layout == plier::tf::DataLayout::NHWC) + { + window_data->window()->vertical(ksize[1]); + window_data->window()->horizontal(ksize[2]); + } + else if (data_layout == plier::tf::DataLayout::NCHW) + { + window_data->window()->vertical(ksize[2]); + window_data->window()->horizontal(ksize[3]); + } + node->annot(std::move(window_data)); + } + + auto value_feature_shape = as_feature_shape(*value_shapedata, node->data_layout()); + + auto stride_data = node->annot(); + assert(stride_data != nullptr); + auto window_data = node->annot(); + assert(window_data != nullptr); + + uint32_t input_height = value_feature_shape.height().value(); + uint32_t input_width = value_feature_shape.width().value(); + uint32_t stride_height = stride_data->stride()->vertical(); + uint32_t stride_width = stride_data->stride()->horizontal(); + uint32_t window_height = window_data->window()->vertical(); + uint32_t window_width = window_data->window()->horizontal(); + uint32_t dilation_height = 1; // dilation for MaxPool is 1 + uint32_t dilation_width = 1; + uint32_t effective_window_height = dilation_height * (window_height - 1) + 1; + uint32_t effective_window_width = dilation_width * (window_width - 1) + 1; + uint32_t output_height; + uint32_t output_width; + + if (padding == "VALID") + { + output_height = (input_height + stride_height - effective_window_height) / stride_height; + output_width = (input_width + stride_width - effective_window_width) / stride_width; + } + else if (padding == "SAME") + { + output_height = (input_height + stride_height - 1) / stride_height; + output_width = (input_width + stride_width - 1) / stride_width; + } + + loco::FeatureShape ofm_feature_shape; + ofm_feature_shape.count() = value_feature_shape.count(); + ofm_feature_shape.height() = output_height; + ofm_feature_shape.width() = output_width; + ofm_feature_shape.depth() = value_feature_shape.depth(); + + auto shape_data = stdex::make_unique(); + as_tensor_shape(*shape_data.get(), ofm_feature_shape, node->data_layout()); + node->annot(std::move(shape_data)); + + INFO(l) << "Fix TFMaxPool shape = ifm" << value_feature_shape << " --> ofm" << ofm_feature_shape; + + return true; } bool fix_shape(moco::tf::TFMul *node) -- 2.7.4