[Relay][AlterOp] Improving support for broadcast layout alteration. (#4040)
authorAnimesh Jain <anijain@umich.edu>
Sun, 6 Oct 2019 04:18:58 +0000 (21:18 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Sun, 6 Oct 2019 04:18:58 +0000 (12:18 +0800)
include/tvm/data_layout.h
src/relay/op/tensor/transform.cc
src/relay/pass/alter_op_layout.cc
src/relay/pass/alter_op_layout.h
src/relay/pass/pattern_util.h
src/relay/qnn/op/convolution.cc
tests/python/relay/test_op_qnn_conv2d.py
tests/python/relay/test_pass_alter_op_layout.py

index a703d92..c2ae572 100644 (file)
@@ -211,6 +211,28 @@ class Layout : public NodeRef {
   }
 
   /*!
+   * \brief Returns a new layout where the dims have been expanded to match the primal dimensions.
+   * \param dst_layout The dst layout to which current layout has to be expanded.
+   * \return The expanded Layout.
+   */
+  inline Layout ExpandPrimal(const Layout& dst_layout) {
+    Layout new_src_layout;
+    // 1) Find the axis which are missing in the current layout. Make them the prefix.
+    std::string new_src_layout_str = "";
+    for (auto dst_axis : dst_layout->axes) {
+      if (LayoutAxis::Get(dst_axis).IsPrimal()) {
+        if (!this->Contains(LayoutAxis::Get(dst_axis))) {
+          new_src_layout_str += dst_axis->var->name_hint;
+        }
+      }
+    }
+    // 2) Now, add the primal axis of the current layout.
+    new_src_layout_str += this->name();
+    new_src_layout = Layout(new_src_layout_str);
+    return new_src_layout;
+  }
+
+  /*!
    * \brief return the index of the input axis.
    *        If it is not found in the layout or the layout is undefined,
    *        return -1.
index 0002390..3f371f2 100644 (file)
@@ -37,6 +37,7 @@
 #include "../op_common.h"
 #include "../../../arithmetic/compute_expr.h"
 #include "../../pass/alter_op_layout.h"
+#include "../../pass/pattern_util.h"
 #include "transform.h"
 
 namespace tvm {
index 9142c0e..23a480b 100644 (file)
@@ -38,6 +38,7 @@
 #include <unordered_map>
 
 #include "alter_op_layout.h"
+#include "pattern_util.h"
 
 namespace tvm {
 namespace relay {
@@ -45,19 +46,35 @@ namespace relay {
 namespace alter_op_layout {
 
 // Make a transform CallNode
+/* Performs 2 operations
+ * 1) If src_layout ndim is smaller then dst_layout, expand_dim is inserted to match the dim size.
+ *    For example, src_layout = C, dst_layout = NCHW16c. The src is expanded to NHWC.
+ * 2) Call layout transform with new src layout.
+ */
 Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
-  if (src_layout.Equals(dst_layout)) { return raw; }
-  CHECK(src_layout.defined() && dst_layout.defined())
-    << "Cannot insert layout transform because there are undefined layouts";
-  CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined())
-    << "Cannot insert layout transform because there are inconvertible layouts: "
-    << src_layout << " v.s. " << dst_layout;
-  static auto &transform_op = Op::Get("layout_transform");
-  NodePtr<LayoutTransformAttrs> attrs = make_node<LayoutTransformAttrs>();
-  attrs->src_layout = src_layout.name();
-  attrs->dst_layout = dst_layout.name();
-  Call transform = CallNode::make(transform_op, {raw}, Attrs{attrs});
-  return std::move(transform);
+  if (src_layout.Equals(dst_layout)) {
+    return raw;
+  }
+
+  // 1) Check if the shape lengths are different. If yes, expand dims.
+  Expr input_expr = raw;
+  Layout new_src_layout = src_layout;
+  if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
+    int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
+    new_src_layout = src_layout.ExpandPrimal(dst_layout);
+    input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
+    if (new_src_layout.Equals(dst_layout)) {
+      return input_expr;
+    }
+  }
+
+  // 2) Insert layout transform on the transformed src.
+  CHECK(new_src_layout.defined() && dst_layout.defined())
+      << "Cannot insert layout transform because there are undefined layouts";
+  CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined())
+      << "Cannot insert layout transform because there are inconvertible layouts: "
+      << new_src_layout << " v.s. " << dst_layout;
+  return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
 }
 
 // Memorize layout transform so we can reuse internal transformed nodes
index 80593a5..350cede 100644 (file)
 
 #include <tvm/data_layout.h>
 #include <tvm/relay/expr.h>
+#include <string>
 
 namespace tvm {
 namespace relay {
 
 /*!
+ * \brief Returns a new layout where the subordinate factors are adjusted based on the tensor
+ *        shape.
+ * \param old_layout The old layout before any transformation.
+ * \param old_shape The shape of the original tensor.
+ * \return The adjusted Layout.
+ */
+inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout,
+                                       const Array<tvm::Expr>& old_shape) {
+  // For each subordinate axis
+  //   1) Find the corresponding dual axis.
+  //   2) Find the Index of this dual axis in old_layout.
+  //   3) Find the shape of the that axis in old_shape.
+  //   4) a) Adjust factor to 1, if that shape is 1. b) Else retain the factor.
+  std::string new_layout;
+  for (auto axis : src_layout->axes) {
+    if (!LayoutAxis::Get(axis).IsPrimal()) {
+      // 1) Find the corresponding dual axis
+      auto dual_axis = LayoutAxis::Get(axis).ToPrimal().name()[0];
+
+      // 2) Find the index of this dual axis in old_layout
+      int old_axis = old_layout.IndexOf(LayoutAxis::Get(dual_axis));
+
+      // 3) Find the shape of this index in old_shape
+      auto shape_val = old_shape[old_axis];
+
+      // 4) a) Check if this shape element is 1.
+      bool is_shape_one = false;
+      if (auto* shape_int = shape_val.as<IntImm>()) {
+        if (shape_int->value == 1) {
+          new_layout += "1";
+          is_shape_one = true;
+        }
+      }
+
+      // 4) b) If shape is not 1, retain the factor.
+      if (!is_shape_one) {
+        auto new_shape_val = src_layout.FactorOf(LayoutAxis::Get(dual_axis));
+        new_layout += std::to_string(new_shape_val);
+      }
+    }
+    new_layout += LayoutAxis::Get(axis).name();
+  }
+  return Layout(new_layout);
+}
+
+/*!
  * \brief Infer & correct function of node layout. See \p Layout for layout convention
  * \param attrs The attribute of the node.
  * \param new_in_layouts The layouts of input arguments after alter_op_layout.
@@ -111,28 +158,39 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
     int scalar = layouts[0].ndim() == 0 ? 0 : 1;
     return Array<Array<Layout> >{layouts, {layouts[1-scalar]}};
   } else {
-    // try to broadcast the tensors to the larger dimension
+    // Set the layout of the larger dimension. If one dimension size is lower, we call expand dims
+    // while transforming layout.
     int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
     int small_idx = 1 - large_idx;
     Layout ret = layouts[large_idx];
 
-    // extract common part
-    size_t i = layouts[large_idx].ndim();
-    for (; i != 0; --i) {
-      const auto& axis = layouts[large_idx][i-1];
-      if (!layouts[small_idx].Contains(axis.ToPrimal())) {
-        break;
-      }
-    }
-
-    Layout common_part = layouts[large_idx].SubLayout(i, layouts[large_idx].ndim() - i);
-    if (!BijectiveLayoutNode::make(layouts[small_idx], common_part).defined()) {
-      // not convertible
-      return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
+    if (old_in_layouts[0].Equals(old_in_layouts[1])) {
+      // Support scenarios where original operands were of type [N, H, W, C] and [N, H, W, 1]
+      // In this case, we might have NCHW16c coming for 1 operand. However, the other operand does
+      // not have enough C dimension. To reuse broadcasting, we would want to use NCHW1c for the
+      // second operand. The following section of code walks through the layouts and shapes to
+      // perform that operation.
+      // a in NCHWC16c
+      // b in NHW1
+      // b = layout_transform(b) from NHW1 -> NCHW1c
+      // add(a, b)
+      auto old_small_shape = old_in_shapes[small_idx];
+      auto old_small_layout = old_in_layouts[small_idx];
+      auto new_small_layout =
+          AdjustSubordinateFactors(layouts[large_idx], old_small_layout, old_small_shape);
+      layouts.Set(small_idx, new_small_layout);
+    } else {
+      // Support scenarios where original operands were of type [N, H, W, C] and [C]. In this case,
+      // while transforming the layout, we expand dims to make C go to NHWC, and then use the
+      // modified layout of the first operator to call the layout transform. E.g.
+      // a in NCHWC16c
+      // b in C
+      // b = expand_dims(b) from C -> NHWC
+      // b = layout_transform(b) from NHWC -> NCHW16c
+      // add(a, b)
+      layouts.Set(small_idx, ret);
     }
-
-    layouts.Set(small_idx, common_part);
-    return Array<Array<Layout> > {layouts, {ret}};
+    return Array<Array<Layout>>{layouts, {ret}};
   }
 }
 
index bf9621b..988b13c 100644 (file)
@@ -505,6 +505,8 @@ Expr MakeSqueeze(Expr data, Array<Integer> axis);
 
 Expr MakeExpandDims(Expr data, int axis, int num_newaxis);
 
+Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout);
+
 Expr StopFusion(Expr data);
 
 Expr CastHint(Expr data, DataType dtype);
index deac4e6..a73a658 100644 (file)
@@ -242,19 +242,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
   if (param->kernel_zero_point != 1) {
     multiplied_t2 = Multiply(zp_kernel, reduced_t2);
   }
-
-  // Replicate to go back to NHWC/NCHW. This is not necessarily needed, but it fails AlterOpLayout.
-  // We can remove this once AlterOpLayout refactoring completes -
-  // https://github.com/dmlc/tvm/issues/3670
-  Array<Integer> reps;
-  if (param->data_layout == "NCHW") {
-    reps = {1, out_channels, 1, 1};
-  } else if (param->data_layout == "NHWC") {
-    reps = {1, 1, 1, out_channels};
-  } else {
-    LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
-  }
-  return Tile(multiplied_t2, reps);
+  return multiplied_t2;
 }
 
 /*
index c8e479d..b4e8bfd 100644 (file)
@@ -607,6 +607,39 @@ def tflite_anistropic_strides():
     golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
     np.testing.assert_equal(qnn_output, golden_output)
 
+def broadcast_layout_test():
+    # Test broadcast support for NHWC layout.
+    data_shape = (1, 229, 229, 3) # NHWC
+    data_dtype = 'uint8'
+    kernel_shape = (7, 7, 3, 64) # HWIO
+    kernel_dtype = 'int8'
+    _, qnn_func = get_funcs(data_shape=data_shape,
+                            data_dtype=data_dtype,
+                            kernel_shape=kernel_shape,
+                            kernel_dtype=kernel_dtype,
+                            input_zero_point=8,
+                            kernel_zero_point=3,
+                            kernel_size=(7, 7),
+                            padding=(1, 1),
+                            strides=(1, 1),
+                            dilation=(1, 1),
+                            data_layout="NHWC",
+                            kernel_layout="HWIO",
+                            out_dtype="int32")
+    func = qnn_func['main'].body
+    bias = relay.var("bias", shape=(64,), dtype="int32")
+    bias2 = relay.var("bias2", shape=(1, 225, 225, 1), dtype="int32")
+
+    # Check broadcast support on both lhs and rhs
+    func = relay.add(func, bias2)
+    func = relay.add(bias2, func)
+    func = relay.add(bias, func)
+    func = relay.add(func, bias)
+    func = relay.Function(relay.analysis.free_vars(func), func)
+    mod = relay.Module.from_expr(func)
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")
+
 if __name__ == "__main__":
     no_zero_point_test()
     input_zero_point_test()
@@ -620,3 +653,4 @@ if __name__ == "__main__":
     tflite_large_irregular_test()
     tflite_output_multiplier_greater_than_one()
     tflite_anistropic_strides()
+    broadcast_layout_test()
index 6b31eed..cc668d7 100644 (file)
@@ -134,7 +134,8 @@ def test_alter_layout():
                             kernel_layout="OIHW16i",
                             data_layout="NCHW16c")
         b = relay.expand_dims(bias, axis=1, num_newaxis=2)
-        b = relay.layout_transform(b, "CHW", "CHW16c")
+        b = relay.expand_dims(b, axis=0, num_newaxis=1)
+        b = relay.layout_transform(b, "NCHW", "NCHW16c")
         y = relay.add(y, b)
 
         y = relay.nn.relu(y)
@@ -304,8 +305,10 @@ def test_alter_layout_broadcast_op():
         weight = relay.var("weight")
         x = relay.layout_transform(x, "NCHW", "NCHW16c")
         bias = relay.expand_dims(bias, 1, 2)
-        bias = relay.layout_transform(bias, "CHW", "CHW16c")
-        scale = relay.layout_transform(scale, "CHW", "CHW16c")
+        bias = relay.expand_dims(bias, 0, 1)
+        bias = relay.layout_transform(bias, "NCHW", "NCHW16c")
+        scale = relay.expand_dims(scale, 0, 1)
+        scale = relay.layout_transform(scale, "NCHW", "NCHW16c")
         y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                             data_layout="NCHW16c")
         y = relay.add(y, bias)          # test broadcasting to lhs