From 56f70f70883c0ed1a606fb851d2f99dde5c247dc Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 28 Aug 2019 10:25:17 +0900 Subject: [PATCH] [moco-tf] update stride and window data (#6969) This will introduce update_stride_data() and update_window_data() common function in FixShapeTransform Signed-off-by: SaeHie Park --- .../moco-tf/src/Transforms/FixShapeTransform.cpp | 36 ++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index 828f928..ce8a61f 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -231,6 +231,42 @@ template void calc_annot_paddata(T *node, const FixPadContext &ctx) assert(node->template annot() != nullptr); } +template void update_stride_data(T *node) +{ + auto stride_data = stdex::make_unique(); + auto strides = node->strides(); + auto data_layout = plier::tf::as_data_layout(node->data_layout()); + if (data_layout == plier::tf::DataLayout::NHWC) + { + stride_data->stride()->vertical(strides[1]); + stride_data->stride()->horizontal(strides[2]); + } + else if (data_layout == plier::tf::DataLayout::NCHW) + { + stride_data->stride()->vertical(strides[2]); + stride_data->stride()->horizontal(strides[3]); + } + node->annot(std::move(stride_data)); +} + +template void update_window_data(T *node) +{ + auto window_data = stdex::make_unique(); + auto ksize = node->ksize(); + auto data_layout = plier::tf::as_data_layout(node->data_layout()); + if (data_layout == plier::tf::DataLayout::NHWC) + { + window_data->window()->vertical(ksize[1]); + window_data->window()->horizontal(ksize[2]); + } + else if (data_layout == plier::tf::DataLayout::NCHW) + { + window_data->window()->vertical(ksize[2]); + window_data->window()->horizontal(ksize[3]); + } + node->annot(std::move(window_data)); +} + bool fix_shape(loco::AvgPool2D *node) { LOGGER(l); -- 2.7.4