[moco-tf] Implement Conv2DBackpropInput canonicalization partially (#7637)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Fri, 20 Sep 2019 03:27:32 +0000 (12:27 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Fri, 20 Sep 2019 03:27:32 +0000 (12:27 +0900)
This commit implements Conv2DBackpropInput canonicalization, except
padding inference.

compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp

index a804ceb..bccb853 100644 (file)
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
+
+#include <loco/IR/Stride.h>
+#include <loco/IR/Padding2D.h>
+#include <loco/Service/ShapeInference.h>
+
+namespace
+{
+using plier::tf::DataLayout;
+
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+  auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+  if (data_layout == DataLayout::NHWC)
+  {
+    enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+    enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+    enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+    enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+  }
+  else if (data_layout == DataLayout::NCHW)
+  {
+    enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+    enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+    enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+    enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+  }
+  else
+    throw std::runtime_error("Not supported data layout");
+
+  feature_enc->encoder(std::move(enc));
+}
+
+void set_filter_enc(loco::FilterEncode *filter_enc)
+{
+  auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
+
+  // In TensorFlow, Conv2dBackpropInput's filter is a 4-D tensor of following shape:
+  // [filter_height, filter_width, out_channels, in_channels] or HWOI or HWNC (in/out in loco sense)
+  enc->perm()->axis(loco::FilterAxis::Height) = 0;
+  enc->perm()->axis(loco::FilterAxis::Width) = 1;
+  enc->perm()->axis(loco::FilterAxis::Count) = 2;
+  enc->perm()->axis(loco::FilterAxis::Depth) = 3;
+
+  filter_enc->encoder(std::move(enc));
+}
+
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+  auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+  if (data_layout == DataLayout::NHWC)
+  {
+    dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+    dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+    dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+    dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+  }
+  else if (data_layout == DataLayout::NCHW)
+  {
+    dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+    dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+    dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+    dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+  }
+  else
+    throw std::runtime_error("Not supported data layout");
+
+  feature_dec->decoder(std::move(dec));
+}
+
+} // namespace
+
+namespace
+{
+
+loco::Stride<2> stride_2d_from_4d(const std::vector<int64_t> &strides_4d,
+                                  const DataLayout data_layout)
+{
+  assert(strides_4d.size() == 4);
+
+  loco::Stride<2> ret;
+  switch (data_layout)
+  {
+    case DataLayout::NHWC:
+      ret.vertical(strides_4d.at(1));
+      ret.horizontal(strides_4d.at(2));
+      break;
+    case DataLayout::NCHW:
+      ret.vertical(strides_4d.at(2));
+      ret.horizontal(strides_4d.at(3));
+      break;
+    default:
+      throw std::runtime_error("Not supported data layout");
+  }
+  return ret;
+}
+
+struct PlaneShape
+{
+  loco::Dimension vertical;
+  loco::Dimension horizontal;
+};
+
+class Padding2DInference final
+{
+public:
+  loco::Padding2D operator()(void);
+
+public:
+  PlaneShape &input() { return _input; }
+  PlaneShape &output() { return _output; }
+  loco::Stride<2> &stride() { return _stride; }
+  loco::Window<2> &window() { return _window; }
+  moco::tf::TFPadding &padding() { return _padding; }
+
+private:
+  PlaneShape _input;
+  PlaneShape _output;
+  loco::Stride<2> _stride;
+  loco::Window<2> _window;
+  moco::tf::TFPadding _padding;
+};
+
+loco::Padding2D Padding2DInference::operator()(void)
+{
+  // TODO Implement!
+  throw std::runtime_error("NYI");
+}
+
+/**
+ * @param[out] ret  PlaneShape extracted from 'node' with given 'data_layout'
+ * @param[in]  node
+ * @param[in]  data_layout
+ *
+ * @return true on success
+ */
+bool set_plane_shape(PlaneShape &ret, const loco::Node *node, const DataLayout data_layout)
+{
+  if (!loco::shape_known(node))
+    return false;
+
+  auto tensor_shape = loco::shape_get(node).as<loco::TensorShape>();
+  assert(tensor_shape.rank() == 4);
+
+  switch (data_layout)
+  {
+    case DataLayout::NHWC:
+      ret.vertical = tensor_shape.dim(1).value();
+      ret.horizontal = tensor_shape.dim(2).value();
+      break;
+    case DataLayout::NCHW:
+      ret.vertical = tensor_shape.dim(2).value();
+      ret.horizontal = tensor_shape.dim(3).value();
+      break;
+    default:
+      throw std::runtime_error("Not supported data layout");
+  }
+
+  return true;
+}
+
+/**
+ * @param[out] ret  2D Window extracted from HW** filter node
+ * @param[in]  filter_node
+ *
+ * @return true on success
+ */
+bool set_window(loco::Window<2> &ret, const loco::Node *filter_node)
+{
+  if (!loco::shape_known(filter_node))
+    return false;
+
+  auto tensor_shape = loco::shape_get(filter_node).as<loco::TensorShape>();
+  assert(tensor_shape.rank() == 4);
+
+  ret.vertical(tensor_shape.dim(0).value());
+  ret.horizontal(tensor_shape.dim(1).value());
+
+  return true;
+}
+
+} // namespace
+
 namespace
 {
 
 bool canonicalize_conv2d_backprop_input(loco::Graph *graph,
                                         moco::tf::TFConv2DBackpropInput *conv2d_backprop)
 {
-  // TODO Implement!
-  return false;
+  /**
+   * @note This will replace TFConv2DBackpropInput node with canonical
+   *       FeatureEncode + FilterEncode + TransposedConv2D + FeatureDecode
+   *
+   * Before
+   *           input_sizes ----
+   *                           \
+   *           filter -------- TFConv2DBackpropInput --- output(s)
+   *                           /
+   *           out_backprop ---
+   *
+   * After
+   *           input_sizes ----
+   *                           \
+   *           filter -------- TFConv2DBackpropInput ---
+   *                           /
+   *           out_backprop ---
+   *
+   *           filter ------ FilterEncode ------ TransposedConv2D --- FeatureDecode --- output(s)
+   *                          (as ker)           /
+   *           out_backprop --- FeatureEncode ---
+   *                             (as ifm)
+   */
+
+  auto data_layout = plier::tf::as_data_layout(conv2d_backprop->data_layout());
+
+  // Nodes to replace
+  auto feature_enc = graph->nodes()->create<loco::FeatureEncode>();
+  auto filter_enc = graph->nodes()->create<loco::FilterEncode>();
+  auto tr_conv2d = graph->nodes()->create<loco::TransposedConv2D>();
+  auto feature_dec = graph->nodes()->create<loco::FeatureDecode>();
+
+  set_feature_enc(feature_enc, data_layout);
+  set_filter_enc(filter_enc);
+  set_feature_dec(feature_dec, data_layout);
+
+  // Attributes for new TransposedConv2D
+  loco::Stride<2> stride;
+  loco::Padding2D pad;
+
+  // Get attributes
+  {
+    stride = stride_2d_from_4d(conv2d_backprop->strides(), data_layout);
+
+    Padding2DInference infer_pad;
+
+    if (!set_plane_shape(infer_pad.input(), conv2d_backprop->out_backprop(), data_layout))
+      return false;
+    if (!set_plane_shape(infer_pad.output(), conv2d_backprop, data_layout))
+      return false;
+    if (!set_window(infer_pad.window(), conv2d_backprop->filter()))
+      return false;
+    infer_pad.stride() = stride;
+    infer_pad.padding() = conv2d_backprop->padding();
+
+    // Run padding infer_pad
+    pad = infer_pad();
+  }
+
+  // Set attributes
+  tr_conv2d->pad()->top(pad.top());
+  tr_conv2d->pad()->bottom(pad.bottom());
+  tr_conv2d->pad()->left(pad.left());
+  tr_conv2d->pad()->right(pad.right());
+
+  tr_conv2d->stride()->vertical(stride.vertical());
+  tr_conv2d->stride()->horizontal(stride.horizontal());
+
+  // Update graph
+  auto input_node = conv2d_backprop->out_backprop();
+  auto filter_node = conv2d_backprop->filter();
+
+  // Update connections
+  feature_enc->input(input_node);
+  filter_enc->input(filter_node);
+  tr_conv2d->ifm(feature_enc);
+  tr_conv2d->ker(filter_enc);
+  feature_dec->input(tr_conv2d);
+
+  // Replace old conv2d_backprop
+  replace(conv2d_backprop).with(feature_dec);
+
+  return true;
 }
 
 } // namespace