[moco-tf] update stride and window data (#6969)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 28 Aug 2019 01:25:17 +0000 (10:25 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 28 Aug 2019 01:25:17 +0000 (10:25 +0900)
This will introduce update_stride_data() and update_window_data() common function in FixShapeTransform

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 828f928..ce8a61f 100644 (file)
@@ -231,6 +231,42 @@ template <class T> void calc_annot_paddata(T *node, const FixPadContext &ctx)
   assert(node->template annot<PadData>() != nullptr);
 }
 
+template <class T> void update_stride_data(T *node)
+{
+  auto stride_data = stdex::make_unique<StrideData>();
+  auto strides = node->strides();
+  auto data_layout = plier::tf::as_data_layout(node->data_layout());
+  if (data_layout == plier::tf::DataLayout::NHWC)
+  {
+    stride_data->stride()->vertical(strides[1]);
+    stride_data->stride()->horizontal(strides[2]);
+  }
+  else if (data_layout == plier::tf::DataLayout::NCHW)
+  {
+    stride_data->stride()->vertical(strides[2]);
+    stride_data->stride()->horizontal(strides[3]);
+  }
+  node->annot(std::move(stride_data));
+}
+
+template <class T> void update_window_data(T *node)
+{
+  auto window_data = stdex::make_unique<WindowData>();
+  auto ksize = node->ksize();
+  auto data_layout = plier::tf::as_data_layout(node->data_layout());
+  if (data_layout == plier::tf::DataLayout::NHWC)
+  {
+    window_data->window()->vertical(ksize[1]);
+    window_data->window()->horizontal(ksize[2]);
+  }
+  else if (data_layout == plier::tf::DataLayout::NCHW)
+  {
+    window_data->window()->vertical(ksize[2]);
+    window_data->window()->horizontal(ksize[3]);
+  }
+  node->annot(std::move(window_data));
+}
+
 bool fix_shape(loco::AvgPool2D *node)
 {
   LOGGER(l);