[moco-tf] Refactor Conv2DCanonicalizer (#7874)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 1 Oct 2019 09:47:32 +0000 (18:47 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 1 Oct 2019 09:47:32 +0000 (18:47 +0900)
This will update Conv2DCanonicalizer to do shape, pad inference and NOT to use annotations

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp

index 3ba9d47..51188ed 100644 (file)
@@ -15,9 +15,7 @@
  */
 
 #include "Conv2DCanonicalizer.h"
-
-#include "Annotations/PadData.h"
-#include "Annotations/StrideData.h"
+#include "TFShapeInferenceHelper.h"
 
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
@@ -83,20 +81,26 @@ bool canonicalize_conv2d(loco::Graph *graph, moco::tf::TFConv2D *node)
   set_filter_enc(filter_enc);
   set_feature_dec(feature_dec, data_layout);
 
-  // Set Conv2D attributes from TFConv2D
-  auto pad_data = node->annot<moco::tf::PadData>();
-  assert(pad_data != nullptr);
+  auto input_shape = moco::tf::node_shape(node->input());
+  assert(input_shape.domain() != loco::Domain::Unknown);
+
+  auto ker_shape = moco::tf::node_shape(node->filter());
+  auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWIO
+
+  auto node_stride = moco::tf::stride_of(node->strides(), node->data_layout());
+  auto node_window = moco::tf::window_of(ker_tensor_shape, "HWIO");
+
+  moco::tf::Padding2DInference infer_padding2d;
 
-  conv2d->pad()->top(pad_data->pad()->top());
-  conv2d->pad()->bottom(pad_data->pad()->bottom());
-  conv2d->pad()->left(pad_data->pad()->left());
-  conv2d->pad()->right(pad_data->pad()->right());
+  infer_padding2d.padding(node->padding());
+  infer_padding2d.stride(node_stride);
+  infer_padding2d.window(node_window);
 
-  auto stride_data = node->annot<moco::tf::StrideData>();
-  assert(stride_data != nullptr);
+  auto input_feature_shape = moco::tf::as_feature_shape(input_shape, node->data_layout());
+  auto input_plane_shape = moco::tf::make_plane_shape(input_feature_shape);
 
-  conv2d->stride()->vertical(stride_data->stride()->vertical());
-  conv2d->stride()->horizontal(stride_data->stride()->horizontal());
+  *conv2d->pad() = infer_padding2d(input_plane_shape);
+  *conv2d->stride() = node_stride;
 
   // update graph
   auto node_A = node->input();