[Relay][AlterOp] NHWC to NCHWc support for Pool, pad, concatenate, sum. (#4059)
author    Animesh Jain <anijain@umich.edu>
Fri, 11 Oct 2019 05:57:09 +0000 (22:57 -0700)
committer Zhi <5145158+zhiics@users.noreply.github.com>
Fri, 11 Oct 2019 05:57:09 +0000 (22:57 -0700)
python/tvm/relay/frontend/tflite.py
src/relay/op/nn/pad.cc
src/relay/op/nn/pooling.cc
src/relay/op/tensor/reduce.cc
src/relay/op/tensor/transform.cc
tests/python/relay/test_pass_alter_op_layout.py

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 8b91315..a519c6f 100644
@@ -748,10 +748,12 @@ class OperatorConverter(object):
         elif padding == Padding.SAME:
             pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
             pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
-            in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
-                                                          (pad_top, pad_bottom),
-                                                          (pad_left, pad_right),
-                                                          (0, 0)))
+            do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
+            if do_pad:
+                in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
+                                                              (pad_top, pad_bottom),
+                                                              (pad_left, pad_right),
+                                                              (0, 0)))
         else:
             raise tvm.error.OpAttributeUnImplemented(
                 'Padding format {} is not supported for operator Conv.'.format(padding))
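
The hunk above only emits nn.pad for SAME padding when at least one pad value is non-zero (for instance, a 1x1 kernel with stride 1 needs no padding at all). A minimal standalone sketch of that check, assuming a local same_pad helper that mirrors the frontend's get_pad_value arithmetic; the helper name and the 56x56 input size are illustrative only:

    import math

    def same_pad(input_size, kernel_size, stride):
        # TFLite SAME padding: the output covers ceil(input / stride) positions.
        output_size = int(math.ceil(float(input_size) / float(stride)))
        total_pad = max(0, (output_size - 1) * stride + kernel_size - input_size)
        pad_before = total_pad // 2
        return pad_before, total_pad - pad_before

    pad_top, pad_bottom = same_pad(input_size=56, kernel_size=1, stride=1)
    pad_left, pad_right = same_pad(input_size=56, kernel_size=1, stride=1)

    # Only emit a pad operator when some side actually needs padding.
    do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
    assert not do_pad  # a 1x1 kernel with stride 1 resolves to all-zero padding
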
diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc
index 331e50f..5127ee4 100644
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 76dec99..503db41 100644
@@ -47,15 +47,9 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
   T *params = const_cast<T*>(attrs.as<T>());
 
   if (new_in_layouts.defined()) {
+    // Adopt the new input layout for the pool op.
     CHECK_EQ(new_in_layouts.size(), 1);
-
-    Layout raw_layout(params->layout);
-    Layout input = new_in_layouts[0];
-    if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
-    input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
-        !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
-      params->layout = input.name();  // modify self to follow the input layout
-    }
+    params->layout = new_in_layouts[0].name();
   }
 
   Layout inferred_layout(params->layout);
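
With this simplification, Pool2DInferCorrectLayout unconditionally follows the layout reported by the altered producer instead of only accepting layouts whose H/W positions match the op's current layout and that carry no split H/W axes. A pure-Python sketch of the old versus new decision, assuming plain strings stand in for Layout objects (hypothetical helpers, not TVM code):

    def old_pool_layout(raw_layout, new_layout):
        # Old rule: keep raw_layout unless H/W positions match and H/W are not split.
        if (new_layout.index('W') == raw_layout.index('W') and
                new_layout.index('H') == raw_layout.index('H') and
                'w' not in new_layout and 'h' not in new_layout):
            return new_layout
        return raw_layout

    def new_pool_layout(raw_layout, new_layout):
        # New rule: always adopt the new input layout.
        return new_layout

    # An NHWC pool fed by a conv2d altered to NCHW16c: the old rule fell back to NHWC,
    # forcing layout_transforms around the pool; the new rule keeps the pool in NCHW16c.
    print(old_pool_layout("NHWC", "NCHW16c"))   # NHWC
    print(new_pool_layout("NHWC", "NCHW16c"))   # NCHW16c
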
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index a7be3ff..e41cfda 100644
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -119,6 +119,59 @@ Array<Integer> GetExcludeAxes(size_t indim,
   return r_axes;
 }
 
+// Return the modified layout for AlterOpLayout pass.
+Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
+                                              const Array<Layout>& new_in_layouts,
+                                              const Array<Layout>& old_in_layouts,
+                                              const Array<Array<IndexExpr>>& old_in_shapes) {
+  // NOTE: Discard "const" qualifier here.
+  ReduceAttrs* params = const_cast<ReduceAttrs*>(attrs.as<ReduceAttrs>());
+
+  // Get the reduce axes.
+  uint32_t indim = old_in_shapes[0].size();
+  auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
+
+  Layout ret = Layout::Undef();
+  if (new_in_layouts.defined() && r_axes.size()) {
+    // Adapt to the new layout: the reduce axes have to change. Record the original reduce axes
+    // and convert them to axes of the modified layout.
+    CHECK_EQ(new_in_layouts.size(), 1);
+    CHECK_EQ(old_in_layouts.size(), 1);
+
+    // 1) Collect the original axes
+    std::unordered_set<std::string> old_r_dims;
+    for (auto r_axis : r_axes) {
+      old_r_dims.emplace(old_in_layouts[0][r_axis].name());
+    }
+
+    // 2) Collect the new axes by walking new_layout.
+    tvm::Array<tvm::Integer> new_r_axes;
+    std::string new_layout_string = "";
+    int axis_index = 0;
+    for (auto iter_var : new_in_layouts[0]->axes) {
+      const auto& layout_axis = LayoutAxis::Get(iter_var);
+      const std::string& layout_dim = layout_axis.name();
+      if (old_r_dims.count(layout_dim)) {
+        new_r_axes.push_back(tvm::Integer(axis_index));
+      }
+      // Collect only the primal axis.
+      if (layout_axis.IsPrimal()) {
+        new_layout_string += layout_dim;
+        axis_index++;
+      }
+    }
+
+    // 3) Set the new axis and layout.
+    ret = Layout(new_layout_string);
+    params->axis = new_r_axes;
+  } else if (old_in_layouts.defined()) {
+    // If the new layout is undefined, set the old layout as the inferred layout.
+    CHECK_EQ(old_in_layouts.size(), 1);
+    ret = old_in_layouts[0];
+  }
+
+  return Array<Array<Layout>>{{ret}, {ret}};
+}
 
 template<typename F>
 Array<Tensor> ReduceCompute(const Attrs& attrs,
@@ -325,6 +378,7 @@ Example::
 .set_attrs_type_key("relay.attrs.ReduceAttrs")
 .set_support_level(4)
 .add_type_rel("Reduce", ReduceRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
 .set_attr<FTVMCompute>("FTVMCompute", SumCompute)
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
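ReduceInferCorrectLayout above remaps the reduce axes from the old layout onto the primal axes of the new layout and reports the primal-only layout as the inferred one, so a packed axis such as 16c in NCHW16c is transformed away before the reduction (as the expected_nchw/expected_nhwc sum tests below show). A standalone sketch of the remapping, assuming plain strings stand in for Layout objects and ignoring the split-axis factor digits; remap_reduce_axes is a hypothetical helper, not a TVM API:

    def remap_reduce_axes(old_layout, new_layout, reduce_axes):
        # 1) Names of the dimensions being reduced in the old layout.
        old_dims = {old_layout[axis] for axis in reduce_axes}

        # 2) Walk the new layout, counting only primal (upper-case) axes.
        new_axes, primal_layout, index = [], "", 0
        for dim in new_layout:
            if dim in old_dims:
                new_axes.append(index)
            if dim.isupper():
                primal_layout += dim
                index += 1
        return new_axes, primal_layout

    # sum over the channel axis of an NHWC tensor whose producer switched to NCHW16c
    print(remap_reduce_axes("NHWC", "NCHW16c", [3]))   # ([1], 'NCHW')
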
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 3f371f2..1d0c9ec 100644
@@ -283,22 +283,34 @@ Array<Array<Layout>> ConcatenateLayout(
     const Array<Layout>& new_in_layouts,
     const Array<Layout>& old_in_layouts,
     const Array<Array<IndexExpr>> &old_in_shapes) {
-  const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
+  ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
 
   size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
                 static_cast<size_t>(param->axis);
 
   Layout ret;
+  bool is_new_layout_selected = false;
   if (new_in_layouts.defined()) {  // this function is called after some operators are alternated.
+    // If all the new input layouts are the same, that layout is selected. The concat axis is then
+    // located in the new layout, and param->axis is modified on the fly to conform to the new
+    // input layout.
     const auto& concate_dim = old_in_layouts[0][axis];
-    for (size_t i = 0; i < new_in_layouts.size(); ++i) {
-      if (new_in_layouts[i].ndim() > axis &&
-          new_in_layouts[i][axis] == concate_dim) {
-        ret = new_in_layouts[i];
-        break;
+    bool all_input_layouts_same = true;
+    for (auto new_layout : new_in_layouts) {
+      if (!new_layout.Equals(new_in_layouts[0])) {
+        all_input_layouts_same = false;
       }
     }
-  } else {  // this function is called on the original correct relay ir
+    if (all_input_layouts_same) {
+      auto new_index = new_in_layouts[0].IndexOf(concate_dim);
+      ret = new_in_layouts[0];
+      param->axis = new_index;
+      is_new_layout_selected = true;
+    }
+  }
+
+  if (!is_new_layout_selected) {
+    // this function is called on the original correct relay ir
     for (size_t i = 0; i < old_in_layouts.size(); ++i) {
       if (old_in_layouts[i].defined()) {
         ret = old_in_layouts[i];
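
The reworked ConcatenateLayout adopts the new layout only when every input reports the same one, and rewrites param->axis to the position of the original concat dimension inside that layout; otherwise it falls back to the old layouts as before. A minimal sketch of the axis rewrite, assuming plain strings stand in for Layout objects (remap_concat_axis is a hypothetical helper, not a TVM API):

    def remap_concat_axis(old_layout, new_layouts, axis):
        # Adopt the new layout only when all inputs agree on it.
        if any(layout != new_layouts[0] for layout in new_layouts):
            return old_layout, axis
        concat_dim = old_layout[axis]                # e.g. 'C' for NHWC with axis=3
        return new_layouts[0], new_layouts[0].index(concat_dim)

    # Two NHWC inputs whose producers were altered to NCHW16c: concat moves to axis 1,
    # and a layout_transform back to NHWC is inserted after it (see expected_nhwc below).
    print(remap_concat_axis("NHWC", ["NCHW16c", "NCHW16c"], 3))   # ('NCHW16c', 1)
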
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index cc668d7..adb8676 100644
@@ -45,6 +45,7 @@ def test_alter_op():
         y = relay.Function([x, weight], y)
         return y
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=100)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
@@ -79,6 +80,7 @@ def test_alter_return_none():
 
     called = [False]
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.global_max_pool2d", level=101)
     def alter_conv2d(attrs, inputs, tinfos):
         called[0] = True
@@ -112,6 +114,7 @@ def test_alter_layout():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=102)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
@@ -180,6 +183,7 @@ def test_alter_layout_dual_path():
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=103)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
@@ -241,6 +245,7 @@ def test_alter_layout_resnet():
         y = relay.nn.global_max_pool2d(y)
         return relay.Function(analysis.free_vars(y), y)
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=104)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
@@ -291,6 +296,7 @@ def test_alter_layout_broadcast_op():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=105)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
@@ -338,6 +344,7 @@ def test_alter_layout_scalar():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=106)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
@@ -370,9 +377,19 @@ def test_alter_layout_scalar():
 
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
+
 def test_alter_layout_concatenate():
-    """ """
-    def before():
+    """ NCHW, NHWC and corner case concatenate layout transform."""
+    # Register alter op layout. "level" is used to override the previously registered functions.
+    @register_alter_op_layout("nn.conv2d", level=107)
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW16c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    # NCHW layout transformation.
+    def before_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
         weight1 = relay.var('weight1')
         weight2 = relay.var('weight2')
@@ -388,14 +405,7 @@ def test_alter_layout_concatenate():
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
-    @register_alter_op_layout("nn.conv2d", level=107)
-    def alter_conv2d(attrs, inputs, tinfos):
-        data, weight = inputs
-        new_attrs = dict(attrs)
-        new_attrs['data_layout'] = 'NCHW16c'
-        return relay.nn.conv2d(data, weight, **new_attrs)
-
-    def expected():
+    def expected_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
         weight1 = relay.var('weight1')
         weight2 = relay.var('weight2')
@@ -415,10 +425,57 @@ def test_alter_layout_concatenate():
         y = relay.Function(analysis.free_vars(ret), ret)
         return y
 
-    a = before()
+    a = before_nchw()
     a = run_opt_pass(a, transform.AlterOpLayout())
 
-    b = expected()
+    b = expected_nchw()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+    # NHWC layout transformation.
+    def before_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        weight2 = relay.var('weight2')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC')
+        y1 = relay.nn.conv2d(y, weight2,
+                             channels=32,
+                             kernel_size=(3, 3),
+                             padding=(1, 1),
+                             data_layout='NHWC')
+        ret = relay.concatenate([y, y1], axis=3)
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        weight2 = relay.var('weight2')
+        y = relay.layout_transform(x, "NHWC", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
+        y1 = relay.nn.conv2d(y, weight2,
+                             channels=32,
+                             kernel_size=(3, 3),
+                             padding=(1, 1),
+                             data_layout='NCHW16c')
+        ret = relay.concatenate([y, y1], axis=1)
+        ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before_nhwc()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nhwc()
     b = run_opt_pass(b, transform.InferType())
 
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
@@ -435,6 +492,7 @@ def test_alter_layout_nchw_upsamping_op():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=108)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
@@ -474,6 +532,7 @@ def test_alter_layout_strided_slice():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=109)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
@@ -511,6 +570,7 @@ def test_alter_layout_depthwise_conv2d():
         return y
 
     import topi
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=110)
     def alter_conv2d(attrs, inputs, tinfos):
         with tvm.target.create("llvm"):
@@ -548,6 +608,7 @@ def test_alter_layout_prelu():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
+    # Register alter op layout. "level" is used to override the previously registered functions.
     @register_alter_op_layout("nn.conv2d", level=111)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
@@ -580,6 +641,167 @@ def test_alter_layout_prelu():
     assert(analysis.alpha_equal(a, b))
 
 
+def test_alter_layout_pool():
+    """ Check NCHW, NHWC pool layout conversion"""
+    # Register alter op layout. "level" is used to override the previously registered functions.
+    @register_alter_op_layout("nn.conv2d", level=113)
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW16c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    # Check NCHW conversion.
+    def before_nchw():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1))
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected_nchw():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NCHW", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
+        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NCHW16c')
+        ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before_nchw()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nchw()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+    # Check NHWC conversion.
+    def before_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC')
+        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NHWC')
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NHWC", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
+        ret = relay.nn.avg_pool2d(y, pool_size=(1, 1), layout='NCHW16c')
+        ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before_nhwc()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nhwc()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_alter_layout_sum():
+    """ Check NCHW, NHWC sum layout conversion"""
+    # Register alter op layout. "level" is used to override the previously registered functions.
+    @register_alter_op_layout("nn.conv2d", level=114)
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW16c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    # Check NCHW conversion.
+    def before_nchw():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        ret = relay.sum(y, axis=1, keepdims=True)
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected_nchw():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NCHW", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
+        ret = relay.layout_transform(y, "NCHW16c", "NCHW")
+        ret = relay.sum(ret, axis=[1], keepdims=True)
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before_nchw()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nchw()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+    # Check NHWC conversion.
+    def before_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC')
+        ret = relay.sum(y, axis=3, keepdims=True)
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NHWC", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
+        ret = relay.layout_transform(y, "NCHW16c", "NCHW")
+        ret = relay.sum(ret, axis=[1], keepdims=True)
+        ret = relay.layout_transform(ret, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before_nhwc()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nhwc()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_alter_op()
     test_alter_return_none()
@@ -593,3 +815,5 @@ if __name__ == "__main__":
     test_alter_layout_strided_slice()
     test_alter_layout_depthwise_conv2d()
     test_alter_layout_prelu()
+    test_alter_layout_pool()
+    test_alter_layout_sum()