From 4da31c518732540cd43cb14de74fed1f5b2bcdcf Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 11 Jul 2019 12:47:45 +0900 Subject: [PATCH] [moco/tf] Use PadData and StrideData for TFConv2d (#4163) * [moco/tf] Use PadData and StrideData for TFConv2d This will fix TFConv2D related FixPadding and FixShare transformations to use PadData and ShapeData to simplify Canonicalization Signed-off-by: SaeHie Park * check data_layout * fix log message * fix log message --- .../moco-tf/src/Transforms/FixPaddingTransform.cpp | 101 ++++++++++++++++++++- .../moco-tf/src/Transforms/FixShapeTransform.cpp | 40 +++++--- 2 files changed, 126 insertions(+), 15 deletions(-) diff --git a/contrib/moco-tf/src/Transforms/FixPaddingTransform.cpp b/contrib/moco-tf/src/Transforms/FixPaddingTransform.cpp index 9bba380..02ee8ab 100644 --- a/contrib/moco-tf/src/Transforms/FixPaddingTransform.cpp +++ b/contrib/moco-tf/src/Transforms/FixPaddingTransform.cpp @@ -18,11 +18,14 @@ #include "Convert.h" #include "Annotations/PaddingData.h" +#include "Annotations/PadData.h" #include "Annotations/ShapeInferenceData.h" +#include "Annotations/StrideData.h" #include "Dialect/TFNodes.h" #include #include +#include #include #include @@ -332,8 +335,102 @@ bool fix_padding(moco::tf::TFBiasAdd *node) bool fix_padding(moco::tf::TFConv2D *node) { - // Nothing to do with padding - return false; + LOGGER(l); + + auto pad_data_c = node->annot(); + if (pad_data_c != nullptr) + { + // padding conversion is already done + return false; + } + + auto stride_data = node->annot(); + if (stride_data == nullptr) + { + // need stride data but not ready yet + return false; + } + + auto ofm_shapedata = node->annot(); + if (ofm_shapedata == nullptr) + { + // need output shape to calculate padding values + return false; + } + + auto ifm = node->ifm(); + assert(ifm != nullptr); + auto ifm_shapedata = ifm->annot(); + if (ifm_shapedata == nullptr) + { + // need input shape to calculate padding values + return false; + } + + auto ker = node->ker(); + assert(ker != nullptr); + auto ker_shapedata = ker->annot(); + if (ker_shapedata == nullptr) + { + return false; + } + + auto padding = node->padding(); + assert(padding == "VALID" || padding == "SAME"); + + auto data_layout = node->data_layout(); + assert(data_layout == "NHWC"); + + auto ifm_tensor_shape = ifm_shapedata->tensor_shape(); // in NHWC + auto ker_tensor_shape = ker_shapedata->tensor_shape(); // in HWIO + auto ofm_tensor_shape = ofm_shapedata->tensor_shape(); // in NHWC + assert(ifm_tensor_shape.rank() == 4); + assert(ker_tensor_shape.rank() == 4); + assert(ofm_tensor_shape.rank() == 4); + + uint32_t input_height = ifm_tensor_shape.dim(1).value(); + uint32_t input_width = ifm_tensor_shape.dim(2).value(); + uint32_t stride_height = stride_data->stride()->vertical(); + uint32_t stride_width = stride_data->stride()->horizontal(); + uint32_t ker_height = ker_tensor_shape.dim(0).value(); + uint32_t ker_width = ker_tensor_shape.dim(1).value(); + uint32_t output_height = ofm_tensor_shape.dim(1).value(); + uint32_t output_width = ofm_tensor_shape.dim(2).value(); + + uint32_t dilation_height = 1; // TODO Consider dilation + uint32_t dilation_width = 1; + uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1; + uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1; + // calculate padding height, width + int32_t i_height = (output_height - 1) * stride_height + effective_ker_height - input_height; + int32_t i_width = (output_width - 1) * stride_width + effective_ker_width - input_width; + uint32_t height = i_height >= 0 ? i_height : 0U; + uint32_t width = i_width >= 0 ? i_width : 0U; + + // annotation of pad data + auto pad_data = stdex::make_unique(); + + pad_data->pad()->top(height / 2); + pad_data->pad()->bottom(height - pad_data->pad()->top()); + pad_data->pad()->left(width / 2); + pad_data->pad()->right(width - pad_data->pad()->left()); + + node->annot(std::move(pad_data)); + + { + auto pad_data = node->annot(); + assert(pad_data != nullptr); + + // clang-format off + INFO(l) << "Fix TFConv2D pad " + << "= T " << pad_data->pad()->top() + << ", L " << pad_data->pad()->left() + << ", B " << pad_data->pad()->bottom() + << ", R " << pad_data->pad()->right() << std::endl; + // clang-format on + } + + return true; } bool fix_padding(moco::tf::TFFusedBatchNorm *node) diff --git a/contrib/moco-tf/src/Transforms/FixShapeTransform.cpp b/contrib/moco-tf/src/Transforms/FixShapeTransform.cpp index d5a7957..62a790f 100644 --- a/contrib/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/contrib/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -19,6 +19,7 @@ #include "Annotations/ConcatData.h" #include "Annotations/PaddingData.h" #include "Annotations/ShapeInferenceData.h" +#include "Annotations/StrideData.h" #include "Dialect/TFNodes.h" #include @@ -604,22 +605,35 @@ bool fix_shape(moco::tf::TFConv2D *node) } auto padding = node->padding(); + assert(padding == "VALID" || padding == "SAME"); // TODO move this to some new Transformation... - auto strides = node->strides(); - auto data_layout = as_DataLayout(node->data_layout()); - loco::Stride<2> stride; - if (data_layout == DataLayout::NHWC) { - stride.vertical(strides[1]); - stride.horizontal(strides[2]); - } - else if (data_layout == DataLayout::NCHW) - { - stride.vertical(strides[2]); - stride.horizontal(strides[3]); + { + auto stride_data = node->annot(); + assert(stride_data == nullptr); + } + auto stride_data = stdex::make_unique(); + auto strides = node->strides(); + auto data_layout = as_DataLayout(node->data_layout()); + if (data_layout == DataLayout::NHWC) + { + stride_data->stride()->vertical(strides[1]); + stride_data->stride()->horizontal(strides[2]); + } + else if (data_layout == DataLayout::NCHW) + { + stride_data->stride()->vertical(strides[2]); + stride_data->stride()->horizontal(strides[3]); + } + node->annot(std::move(stride_data)); } + auto stride_data = node->annot(); + assert(stride_data != nullptr); + INFO(l) << "FixShape TFConv2D strides = " << stride_data->stride()->vertical() << ", " + << stride_data->stride()->horizontal(); + auto ifm_tensor_shape = ifm_shapedata->tensor_shape(); // in NHWC auto ker_tensor_shape = ker_shapedata->tensor_shape(); // in HWIO assert(ifm_tensor_shape.rank() == 4); @@ -627,8 +641,8 @@ bool fix_shape(moco::tf::TFConv2D *node) uint32_t input_height = ifm_tensor_shape.dim(1).value(); uint32_t input_width = ifm_tensor_shape.dim(2).value(); - uint32_t stride_height = stride.vertical(); - uint32_t stride_width = stride.horizontal(); + uint32_t stride_height = stride_data->stride()->vertical(); + uint32_t stride_width = stride_data->stride()->horizontal(); uint32_t ker_height = ker_tensor_shape.dim(0).value(); uint32_t ker_width = ker_tensor_shape.dim(1).value(); uint32_t dilation_height = 1; // TODO Consider dilation -- 2.7.4