From cfd82493182f4763083faa15a29341823650a921 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 17 Jun 2019 16:23:50 +0900 Subject: [PATCH] [moco/tf] Fix shape transformation for AvgPool2D (#3812) This will fix shape transformation formula for AvgPool2D with padding type Signed-off-by: SaeHie Park --- .../tf/src/Transforms/FixShapeTransform.cpp | 23 ++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp b/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp index f0fca60..bd8c29e 100644 --- a/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp +++ b/contrib/moco/lib/frontend/tf/src/Transforms/FixShapeTransform.cpp @@ -111,8 +111,27 @@ bool fix_shape(loco::AvgPool2D *node) uint32_t input_width = ifm_feature_shape.width().value(); uint32_t stride_height = node->stride()->vertical(); uint32_t stride_width = node->stride()->horizontal(); - uint32_t output_height = (input_height + stride_height - 1) / stride_height; - uint32_t output_width = (input_width + stride_width - 1) / stride_width; + uint32_t window_height = node->window()->vertical(); + uint32_t window_width = node->window()->horizontal(); + uint32_t dilation_height = 1; // dilation for AvgPool2D is 1 + uint32_t dilation_width = 1; + uint32_t effective_window_height = dilation_height * (window_height - 1) + 1; + uint32_t effective_window_width = dilation_width * (window_width - 1) + 1; + uint32_t output_height; + uint32_t output_width; + + if (padding_data->padding() == "VALID") + { + output_height = (input_height + stride_height - effective_window_height) / stride_height; + output_width = (input_width + stride_width - effective_window_width) / stride_width; + } + else if (padding_data->padding() == "SAME") + { + output_height = (input_height + stride_height - 1) / stride_height; + output_width = (input_width + stride_width - 1) / stride_width; + } + else + throw std::runtime_error("Not supported padding type in FixShapeTransform AvgPool2D"); loco::FeatureShape ofm_feature_shape; ofm_feature_shape.count() = ifm_feature_shape.count(); -- 2.7.4