[moco-tf] Fix shape and pad for TFAvgPool (#5996)
author SaeHie Park/On-Device Lab(SR)/Principal Engineer/Samsung Electronics <saehie.park@samsung.com>
Tue, 30 Jul 2019 03:43:35 +0000 (12:43 +0900)
committer GitHub Enterprise <noreply-CODE@samsung.com>
Tue, 30 Jul 2019 03:43:35 +0000 (12:43 +0900)
This will implement FixShape and FixPadding for the TFAvgPool node
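
As implemented below, the output size and padding follow the usual TensorFlow
rules, with effective_window = window for AvgPool (dilation is 1) and all
divisions being integer divisions:

  VALID: output = (input + stride - effective_window) / stride
  SAME : output = (input + stride - 1) / stride
  pad_total = max(0, (output - 1) * stride + effective_window - input)

pad_total is split so that top/left gets pad_total / 2 and bottom/right gets
the remainder. For example, a 5x5 input with a 3x3 window, stride 2 and SAME
padding gives a 3x3 output and pad_total = (3-1)*2 + 3 - 5 = 2 per spatial
dimension, split as 1 on top/left and 1 on bottom/right.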

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Transforms/FixPaddingTransform.cpp
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 3d89717..f49bd9b 100644 (file)
@@ -20,6 +20,7 @@
 #include "Annotations/PadData.h"
 #include "Annotations/ShapeInferenceData.h"
 #include "Annotations/StrideData.h"
+#include "Annotations/WindowData.h"
 #include "Dialect/TFNodes.h"
 
 #include <loco.h>
@@ -328,8 +329,91 @@ bool fix_padding(moco::tf::TFAdd *node)
 
 bool fix_padding(moco::tf::TFAvgPool *node)
 {
-  // Nothing to do with padding
-  return false;
+  LOGGER(l);
+
+  auto pad_data_c = node->annot<PadData>();
+  if (pad_data_c != nullptr)
+  {
+    // padding conversion is already done
+    return false;
+  }
+
+  auto ofm_shapedata = node->annot<ShapeInferenceData>();
+  if (ofm_shapedata == nullptr)
+  {
+    // need output shape to calculate padding values
+    return false;
+  }
+  auto value = node->value();
+  assert(value != nullptr);
+  auto value_shapedata = value->annot<ShapeInferenceData>();
+  if (value_shapedata == nullptr)
+  {
+    // need input shape to calculate padding values
+    return false;
+  }
+  auto stride_data = node->annot<StrideData>();
+  if (stride_data == nullptr)
+  {
+    // need stride_data from FixShape
+    return false;
+  }
+  auto window_data = node->annot<WindowData>();
+  if (window_data == nullptr)
+  {
+    // need window_data from FixShape
+    return false;
+  }
+
+  auto padding = node->padding();
+  assert(padding == "VALID" || padding == "SAME");
+  assert(ofm_shapedata->rank() == 4);
+  assert(value_shapedata->rank() == 4);
+
+  auto value_feature_shape = as_feature_shape(*value_shapedata, node->data_layout());
+  auto ofm_feature_shape = as_feature_shape(*ofm_shapedata, node->data_layout());
+
+  uint32_t input_height = value_feature_shape.height().value();
+  uint32_t input_width = value_feature_shape.width().value();
+  uint32_t stride_height = stride_data->stride()->vertical();
+  uint32_t stride_width = stride_data->stride()->horizontal();
+  uint32_t window_height = window_data->window()->vertical();
+  uint32_t window_width = window_data->window()->horizontal();
+  uint32_t output_height = ofm_feature_shape.height().value();
+  uint32_t output_width = ofm_feature_shape.width().value();
+  uint32_t dilation_height = 1; // dilation for AvgPool is 1
+  uint32_t dilation_width = 1;
+  uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
+  uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
+  // calculate padding height, width
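+  // total padding = (output - 1) * stride + effective_window - input, clamped to 0
+  // e.g. input 5, window 3, stride 2, SAME output 3 -> (3-1)*2 + 3 - 5 = 2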
+  int32_t i_height = (output_height - 1) * stride_height + effective_window_height - input_height;
+  int32_t i_width = (output_width - 1) * stride_width + effective_window_width - input_width;
+  uint32_t height = i_height >= 0 ? i_height : 0U;
+  uint32_t width = i_width >= 0 ? i_width : 0U;
+
+  // annotation of pad data
+  auto pad_data = stdex::make_unique<PadData>();
+
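+  // split the total padding: top/left gets the smaller half, bottom/right the
+  // remainder (the extra row/column when the total is odd), as TensorFlow SAME does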
+  pad_data->pad()->top(height / 2);
+  pad_data->pad()->bottom(height - pad_data->pad()->top());
+  pad_data->pad()->left(width / 2);
+  pad_data->pad()->right(width - pad_data->pad()->left());
+
+  node->annot(std::move(pad_data));
+
+  {
+    auto pad_data = node->annot<PadData>();
+    assert(pad_data != nullptr);
+
+    // clang-format off
+    INFO(l) << "Fix TFAvgPool pad "
+            << "= T " << pad_data->pad()->top()
+            << ", L " << pad_data->pad()->left()
+            << ", B " << pad_data->pad()->bottom()
+            << ", R " << pad_data->pad()->right() << std::endl;
+    // clang-format on
+  }
+  return true;
 }
 
 bool fix_padding(moco::tf::TFBiasAdd *node)
index c75f347..36cc79d 100644 (file)
@@ -22,6 +22,7 @@
 #include "Annotations/PaddingData.h"
 #include "Annotations/ShapeInferenceData.h"
 #include "Annotations/StrideData.h"
+#include "Annotations/WindowData.h"
 #include "Dialect/TFNodes.h"
 
 #include <loco.h>
@@ -564,8 +565,109 @@ bool fix_shape(moco::tf::TFAdd *node)
 
 bool fix_shape(moco::tf::TFAvgPool *node)
 {
-  // TODO implement
-  throw std::runtime_error("NYI fix_shape TFAvgPool");
+  LOGGER(l);
+
+  auto shapedata = node->annot<ShapeInferenceData>();
+  if (shapedata != nullptr)
+  {
+    // shape inference is already done for TFAvgPool
+    return false;
+  }
+  auto value = node->value();
+  auto value_shapedata = value->annot<ShapeInferenceData>();
+  if (value_shapedata == nullptr)
+  {
+    // input node shape inference is not ready
+    return false;
+  }
+
+  auto padding = node->padding();
+  assert(padding == "VALID" || padding == "SAME");
+
+  // TODO move this to some new Transformation...
+  {
+    {
+      auto stride_data = node->annot<StrideData>();
+      assert(stride_data == nullptr);
+    }
+    auto stride_data = stdex::make_unique<StrideData>();
+    auto strides = node->strides();
+    auto data_layout = plier::tf::as_data_layout(node->data_layout());
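+    // strides (and ksize below) are given in the node's data layout:
+    // NHWC = [batch, height, width, channel], NCHW = [batch, channel, height, width]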
+    if (data_layout == plier::tf::DataLayout::NHWC)
+    {
+      stride_data->stride()->vertical(strides[1]);
+      stride_data->stride()->horizontal(strides[2]);
+    }
+    else if (data_layout == plier::tf::DataLayout::NCHW)
+    {
+      stride_data->stride()->vertical(strides[2]);
+      stride_data->stride()->horizontal(strides[3]);
+    }
+    node->annot(std::move(stride_data));
+
+    {
+      auto window_data = node->annot<WindowData>();
+      assert(window_data == nullptr);
+    }
+    auto window_data = stdex::make_unique<WindowData>();
+    auto ksize = node->ksize();
+    if (data_layout == plier::tf::DataLayout::NHWC)
+    {
+      window_data->window()->vertical(ksize[1]);
+      window_data->window()->horizontal(ksize[2]);
+    }
+    else if (data_layout == plier::tf::DataLayout::NCHW)
+    {
+      window_data->window()->vertical(ksize[2]);
+      window_data->window()->horizontal(ksize[3]);
+    }
+    node->annot(std::move(window_data));
+  }
+
+  auto value_feature_shape = as_feature_shape(*value_shapedata, node->data_layout());
+
+  auto stride_data = node->annot<StrideData>();
+  assert(stride_data != nullptr);
+  auto window_data = node->annot<WindowData>();
+  assert(window_data != nullptr);
+
+  uint32_t input_height = value_feature_shape.height().value();
+  uint32_t input_width = value_feature_shape.width().value();
+  uint32_t stride_height = stride_data->stride()->vertical();
+  uint32_t stride_width = stride_data->stride()->horizontal();
+  uint32_t window_height = window_data->window()->vertical();
+  uint32_t window_width = window_data->window()->horizontal();
+  uint32_t dilation_height = 1; // dilation is 1
+  uint32_t dilation_width = 1;
+  uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
+  uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
+  uint32_t output_height;
+  uint32_t output_width;
+
+  if (padding == "VALID")
+  {
+    output_height = (input_height + stride_height - effective_window_height) / stride_height;
+    output_width = (input_width + stride_width - effective_window_width) / stride_width;
+  }
+  else if (padding == "SAME")
+  {
+    output_height = (input_height + stride_height - 1) / stride_height;
+    output_width = (input_width + stride_width - 1) / stride_width;
+  }
+
+  loco::FeatureShape ofm_feature_shape;
+  ofm_feature_shape.count() = value_feature_shape.count();
+  ofm_feature_shape.height() = output_height;
+  ofm_feature_shape.width() = output_width;
+  ofm_feature_shape.depth() = value_feature_shape.depth();
+
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  as_tensor_shape(*shape_data.get(), ofm_feature_shape, node->data_layout());
+  node->annot(std::move(shape_data));
+
+  INFO(l) << "Fix TFAvgPool shape = ifm" << value_feature_shape << " --> ofm" << ofm_feature_shape;
+
+  return true;
 }
 
 bool fix_shape(moco::tf::TFBiasAdd *node)