[mogo-tf] use update_stride_data and update_window_data (#6987)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 28 Aug 2019 07:25:29 +0000 (16:25 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 28 Aug 2019 07:25:29 +0000 (16:25 +0900)
This will update to use update_stride_data() and update_window_data() in fix_shape() functions

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 2d6355b..7467b01 100644 (file)
@@ -868,45 +868,8 @@ bool fix_shape(moco::tf::TFAvgPool *node)
   auto padding = node->padding();
   assert(padding == "VALID" || padding == "SAME");
 
-  // TODO move this to some new Transformation...
-  {
-    {
-      auto stride_data = node->annot<StrideData>();
-      assert(stride_data == nullptr);
-    }
-    auto stride_data = stdex::make_unique<StrideData>();
-    auto strides = node->strides();
-    auto data_layout = plier::tf::as_data_layout(node->data_layout());
-    if (data_layout == plier::tf::DataLayout::NHWC)
-    {
-      stride_data->stride()->vertical(strides[1]);
-      stride_data->stride()->horizontal(strides[2]);
-    }
-    else if (data_layout == plier::tf::DataLayout::NCHW)
-    {
-      stride_data->stride()->vertical(strides[2]);
-      stride_data->stride()->horizontal(strides[3]);
-    }
-    node->annot(std::move(stride_data));
-
-    {
-      auto window_data = node->annot<WindowData>();
-      assert(window_data == nullptr);
-    }
-    auto window_data = stdex::make_unique<WindowData>();
-    auto ksize = node->ksize();
-    if (data_layout == plier::tf::DataLayout::NHWC)
-    {
-      window_data->window()->vertical(ksize[1]);
-      window_data->window()->horizontal(ksize[2]);
-    }
-    else if (data_layout == plier::tf::DataLayout::NCHW)
-    {
-      window_data->window()->vertical(ksize[2]);
-      window_data->window()->horizontal(ksize[3]);
-    }
-    node->annot(std::move(window_data));
-  }
+  update_stride_data(node);
+  update_window_data(node);
 
   auto value_feature_shape = as_feature_shape(*value_shapedata, node->data_layout());
 
@@ -1184,27 +1147,7 @@ bool fix_shape(moco::tf::TFConv2D *node)
   auto padding = node->padding();
   assert(padding == "VALID" || padding == "SAME");
 
-  // TODO move this to some new Transformation...
-  {
-    {
-      auto stride_data = node->annot<StrideData>();
-      assert(stride_data == nullptr);
-    }
-    auto stride_data = stdex::make_unique<StrideData>();
-    auto strides = node->strides();
-    auto data_layout = plier::tf::as_data_layout(node->data_layout());
-    if (data_layout == plier::tf::DataLayout::NHWC)
-    {
-      stride_data->stride()->vertical(strides[1]);
-      stride_data->stride()->horizontal(strides[2]);
-    }
-    else if (data_layout == plier::tf::DataLayout::NCHW)
-    {
-      stride_data->stride()->vertical(strides[2]);
-      stride_data->stride()->horizontal(strides[3]);
-    }
-    node->annot(std::move(stride_data));
-  }
+  update_stride_data(node);
 
   auto stride_data = node->annot<StrideData>();
   assert(stride_data != nullptr);
@@ -1293,29 +1236,7 @@ bool fix_shape(moco::tf::TFDepthwiseConv2dNative *node)
     return false;
   }
 
-  {
-    auto stride_data = node->annot<StrideData>();
-    assert(stride_data == nullptr);
-  }
-
-  auto stride_copy = stdex::make_unique<StrideData>();
-  auto strides = node->strides();
-  auto data_layout = plier::tf::as_data_layout(node->data_layout());
-  if (data_layout == plier::tf::DataLayout::NHWC)
-  {
-    stride_copy->stride()->vertical(strides[1]);
-    stride_copy->stride()->horizontal(strides[2]);
-  }
-  else if (data_layout == plier::tf::DataLayout::NCHW)
-  {
-    stride_copy->stride()->vertical(strides[2]);
-    stride_copy->stride()->horizontal(strides[3]);
-  }
-  else
-  {
-    throw std::runtime_error{"Not supported for other data layout"};
-  }
-  node->annot(std::move(stride_copy));
+  update_stride_data(node);
 
   auto stride_data = node->annot<StrideData>();
   assert(stride_data != nullptr);
@@ -1414,45 +1335,8 @@ bool fix_shape(moco::tf::TFMaxPool *node)
   auto padding = node->padding();
   assert(padding == "VALID" || padding == "SAME");
 
-  // TODO move this to some new Transformation...
-  {
-    {
-      auto stride_data = node->annot<StrideData>();
-      assert(stride_data == nullptr);
-    }
-    auto stride_data = stdex::make_unique<StrideData>();
-    auto strides = node->strides();
-    auto data_layout = plier::tf::as_data_layout(node->data_layout());
-    if (data_layout == plier::tf::DataLayout::NHWC)
-    {
-      stride_data->stride()->vertical(strides[1]);
-      stride_data->stride()->horizontal(strides[2]);
-    }
-    else if (data_layout == plier::tf::DataLayout::NCHW)
-    {
-      stride_data->stride()->vertical(strides[2]);
-      stride_data->stride()->horizontal(strides[3]);
-    }
-    node->annot(std::move(stride_data));
-
-    {
-      auto window_data = node->annot<WindowData>();
-      assert(window_data == nullptr);
-    }
-    auto window_data = stdex::make_unique<WindowData>();
-    auto ksize = node->ksize();
-    if (data_layout == plier::tf::DataLayout::NHWC)
-    {
-      window_data->window()->vertical(ksize[1]);
-      window_data->window()->horizontal(ksize[2]);
-    }
-    else if (data_layout == plier::tf::DataLayout::NCHW)
-    {
-      window_data->window()->vertical(ksize[2]);
-      window_data->window()->horizontal(ksize[3]);
-    }
-    node->annot(std::move(window_data));
-  }
+  update_stride_data(node);
+  update_window_data(node);
 
   auto value_feature_shape = as_feature_shape(*value_shapedata, node->data_layout());