[moco-tf] ShapeInf rev for DepthwiseConv2dNative (#7974)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 8 Oct 2019 07:19:28 +0000 (16:19 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 8 Oct 2019 07:19:28 +0000 (16:19 +0900)
This will revise shape inference for DepthwiseConv2dNative to be done in TFShapeInferenceRule

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp
compiler/moco-tf/src/TFShapeInferenceHelper.cpp
compiler/moco-tf/src/Transforms/FixShapeTransform.cpp

index 561e04e..059b71b 100644 (file)
@@ -15,9 +15,7 @@
  */
 
 #include "DepthwiseConv2dNativeCanonicalizer.h"
-
-#include "Annotations/PadData.h"
-#include "Annotations/StrideData.h"
+#include "TFShapeInferenceHelper.h"
 
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
@@ -88,20 +86,24 @@ bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::tf::TFDepthwis
   set_filter_enc(filter_enc);
   set_feature_dec(feature_dec, data_layout);
 
-  // Set DetphwiseConv2D attributes from TFDepthwiseConv2dNative
-  auto pad_data = node->annot<moco::tf::PadData>();
-  assert(pad_data != nullptr);
+  // Calculate Pad and Stride from inference
+  auto input_shape = moco::tf::node_shape(node->input());
+  auto ker_shape = moco::tf::node_shape(node->filter());
+  auto ker_tensor_shape = ker_shape.as<loco::TensorShape>();
+  auto node_stride = moco::tf::stride_of(node->strides(), node->data_layout());
+  auto node_window = moco::tf::window_of(ker_tensor_shape, "HWCM");
+
+  moco::tf::Padding2DInference infer_padding2d;
 
-  depthwiseconv2d->pad()->top(pad_data->pad()->top());
-  depthwiseconv2d->pad()->bottom(pad_data->pad()->bottom());
-  depthwiseconv2d->pad()->left(pad_data->pad()->left());
-  depthwiseconv2d->pad()->right(pad_data->pad()->right());
+  infer_padding2d.padding(node->padding());
+  infer_padding2d.stride(node_stride);
+  infer_padding2d.window(node_window);
 
-  auto stride_data = node->annot<moco::tf::StrideData>();
-  assert(stride_data != nullptr);
+  auto input_feature_shape = moco::tf::as_feature_shape(input_shape, node->data_layout());
+  auto input_plane_shape = moco::tf::make_plane_shape(input_feature_shape);
 
-  depthwiseconv2d->stride()->vertical(stride_data->stride()->vertical());
-  depthwiseconv2d->stride()->horizontal(stride_data->stride()->horizontal());
+  *depthwiseconv2d->pad() = infer_padding2d(input_plane_shape);
+  *depthwiseconv2d->stride() = node_stride;
 
   // update graph
   auto node_A = node->input();
index 531693a..ab34bf1 100644 (file)
@@ -199,6 +199,34 @@ public:
     return loco::NodeShape(output_tensor_shape);
   }
 
+  loco::NodeShape visit(const moco::tf::TFDepthwiseConv2dNative *node) final
+  {
+    auto input_shape = moco::tf::node_shape(node->input()); // NHWC
+    auto ker_shape = moco::tf::node_shape(node->filter());
+    auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWCM
+    auto node_stride = moco::tf::stride_of(node->strides(), node->data_layout());
+    auto node_window = moco::tf::window_of(ker_tensor_shape, "HWCM");
+
+    moco::tf::PlaneInference infer_plane_shape;
+
+    infer_plane_shape.padding(node->padding());
+    infer_plane_shape.stride(node_stride);
+    infer_plane_shape.window(node_window);
+
+    auto input_feature_shape = moco::tf::as_feature_shape(input_shape, node->data_layout());
+    auto input_plane_shape = moco::tf::make_plane_shape(input_feature_shape);
+    // output count is from input count, depth is from kernel 'CM' which is dim(2) * dim(3)
+    auto output_feature_shape = input_feature_shape;
+    output_feature_shape.depth() =
+        loco::Dimension(ker_tensor_shape.dim(2).value() * ker_tensor_shape.dim(3).value());
+
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+    moco::tf::update(output_feature_shape).with(output_plane_shape);
+
+    return moco::tf::as_tensor_shape(output_feature_shape, node->data_layout());
+  }
+
 public:
   loco::NodeShape visit(const moco::tf::TFNode *node) final
   {
index 1a732e8..a7604b9 100644 (file)
@@ -348,6 +348,11 @@ loco::Window<2> window_of(const loco::TensorShape &shape, const TFDataLayout &da
     window.vertical(shape.dim(0).value());
     window.horizontal(shape.dim(1).value());
   }
+  else if (datalayout == "HWCM")
+  {
+    window.vertical(shape.dim(0).value());
+    window.horizontal(shape.dim(1).value());
+  }
   else
   {
     // TODO add more datalayout supports if needed
index 493be50..9561f83 100644 (file)
@@ -583,91 +583,7 @@ bool fix_shape(moco::tf::TFConv2DBackpropInput *node)
   return true;
 }
 
-bool fix_shape(moco::tf::TFDepthwiseConv2dNative *node)
-{
-  LOGGER(l);
-
-  if (shape_inference_done(node))
-    return false;
-
-  auto ifm = node->input();
-  loco::NodeShape ifm_shape;
-  if (!node_shape(ifm, ifm_shape))
-  {
-    // input node shape inference is not ready
-    return false;
-  }
-
-  auto ker = node->filter();
-  loco::NodeShape ker_shape;
-  if (!node_shape(ker, ker_shape))
-  {
-    return false;
-  }
-
-  update_stride_data(node);
-
-  auto stride_data = node->annot<StrideData>();
-  assert(stride_data != nullptr);
-
-  INFO(l) << "FixShape TFDepthwiseConv2dNative strides = " << stride_data->stride()->vertical()
-          << ", " << stride_data->stride()->horizontal();
-
-  auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>(); // in NHWC
-  auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWCM
-  assert(ifm_tensor_shape.rank() == 4);
-  assert(ker_tensor_shape.rank() == 4);
-
-  uint32_t input_height = ifm_tensor_shape.dim(1).value();
-  uint32_t input_width = ifm_tensor_shape.dim(2).value();
-  uint32_t stride_height = stride_data->stride()->vertical();
-  uint32_t stride_width = stride_data->stride()->horizontal();
-  uint32_t ker_height = ker_tensor_shape.dim(0).value();
-  uint32_t ker_width = ker_tensor_shape.dim(1).value();
-  uint32_t dilation_height = 1; // TODO Consider dilation
-  uint32_t dilation_width = 1;
-  uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
-  uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
-  uint32_t output_height;
-  uint32_t output_width;
-
-  auto padding = node->padding();
-  assert(padding == "VALID" || padding == "SAME");
-
-  if (padding == "VALID")
-  {
-    output_height = (input_height + stride_height - effective_ker_height) / stride_height;
-    output_width = (input_width + stride_width - effective_ker_width) / stride_width;
-  }
-  else // padding == "SAME"
-  {
-    output_height = (input_height + stride_height - 1) / stride_height;
-    output_width = (input_width + stride_width - 1) / stride_width;
-  }
-
-  loco::TensorShape ofm_tensor_shape;
-  ofm_tensor_shape.rank(4);
-  ofm_tensor_shape.dim(0) = ifm_tensor_shape.dim(0);
-  ofm_tensor_shape.dim(1) = output_height;
-  ofm_tensor_shape.dim(2) = output_width;
-  ofm_tensor_shape.dim(3) =
-      loco::Dimension(ker_tensor_shape.dim(2).value() * ker_tensor_shape.dim(3).value());
-
-  auto shape_data = stdex::make_unique<ShapeInferenceData>();
-  shape_data->tensor_shape(ofm_tensor_shape);
-  node->annot(std::move(shape_data));
-
-  FixPadContext ctx = {input_height,  input_width,  output_height,        output_width,
-                       stride_height, stride_width, effective_ker_height, effective_ker_width};
-
-  calc_annot_paddata(node, ctx);
-
-  INFO(l) << "Fix TFDepthwiseConv2dNative shape = ifm" << ifm_tensor_shape << " ker"
-          << ker_tensor_shape << " --> ofm" << ofm_tensor_shape;
-  INFO(l) << "                            pad = " << *node->annot<PadData>();
-
-  return true;
-}
+bool fix_shape(moco::tf::TFDepthwiseConv2dNative *node) { return false; }
 
 bool fix_shape(moco::tf::TFFusedBatchNorm *node)
 {