[moco/tf] Implement pad fixing for AvgPool2D (#3790)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 14 Jun 2019 08:09:28 +0000 (17:09 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 14 Jun 2019 08:09:28 +0000 (17:09 +0900)
This will implement pad fixing for AvgPool2D

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco/lib/frontend/tf/src/Transforms/FixPaddingTransform.cpp

index 3695cb5..a40dfc4 100644 (file)
 
 #include "FixPaddingTransform.h"
 
+#include "Convert.h"
+#include "Annotations/PaddingData.h"
+#include "Annotations/ShapeInferenceData.h"
+
 #include <loco.h>
 
+#include <cassert>
 #include <stdexcept>
 
+/**
+ * @note To fix padding, output shape of the node needs to be fixed first.
+ *       fix_padding() checks if fix padding is needed by existance of
+ *       PaddingData annotation, then output shape is checked by existance
+ *       of ShapeInferenceData.
+ */
+
 namespace
 {
 
@@ -27,8 +39,55 @@ using namespace moco::tf;
 
 bool fix_padding(loco::AvgPool2D *node)
 {
-  // TODO implement this
-  return false;
+  auto padding_data = node->annot<PaddingData>();
+  if (padding_data == nullptr)
+  {
+    // padding conversion is already done
+    return false;
+  }
+  auto ofm_shapedata = node->annot<ShapeInferenceData>();
+  if (ofm_shapedata == nullptr)
+  {
+    // need output shape to calculate padding values
+    return false;
+  }
+  auto ifm = node->ifm();
+  assert(ifm != nullptr);
+  auto ifm_shapedata = ifm->annot<ShapeInferenceData>();
+  if (ifm_shapedata == nullptr)
+  {
+    // need input shape to calculate padding values
+    return false;
+  }
+
+  auto padding = padding_data->padding();
+  padding = moco::str_toupper(padding);
+  assert(padding == "VALID" || padding == "SAME");
+  assert(ofm_shapedata->rank() == 4);
+  assert(ifm_shapedata->rank() == 4);
+
+  auto ifm_feature_shape = ifm_shapedata->feature_shape();
+  auto ofm_feature_shape = ofm_shapedata->feature_shape();
+
+  uint32_t input_height = ifm_feature_shape.height().value();
+  uint32_t input_width = ifm_feature_shape.width().value();
+  uint32_t stride_height = node->stride()->vertical();
+  uint32_t stride_width = node->stride()->horizontal();
+  uint32_t output_height = ofm_feature_shape.height().value();
+  uint32_t output_width = ofm_feature_shape.width().value();
+  // calculate padding height, width
+  uint32_t height = std::max(0U, ((output_height - 1) * stride_height + 1 - input_height) / 2);
+  uint32_t width = std::max(0U, ((output_width - 1) * stride_width + 1 - input_width) / 2);
+
+  // set padding values
+  node->pad()->top(height / 2);
+  node->pad()->bottom(height - node->pad()->top());
+  node->pad()->left(width / 2);
+  node->pad()->right(width - node->pad()->left());
+
+  // clear annotation PaddingData
+  node->annot<PaddingData>(nullptr);
+  return true;
 }
 
 bool fix_padding(loco::ConstGen *node)