[Relay] InferCorrectLayout for strided_slice & min_num_branches option in CombinePara...
authorWuwei Lin <vincentl13x@gmail.com>
Tue, 9 Apr 2019 05:20:56 +0000 (13:20 +0800)
committerYizhi Liu <liuyizhi@apache.org>
Tue, 9 Apr 2019 05:20:56 +0000 (22:20 -0700)
* [Relay] InferCorrectLayout for strided_slice

* Add min_num_branches option to CombineParallelConv2D

* Return undef if original layout contains split axes

python/tvm/relay/ir_pass.py
src/relay/op/tensor/transform.cc
src/relay/pass/combine_parallel_conv2d.cc
tests/python/relay/test_pass_alter_op_layout.py
tests/python/relay/test_pass_combine_parallel_conv2d.py

index 8eb0adc..b3d323b 100644 (file)
@@ -722,20 +722,23 @@ def fuse_ops(expr, opt_level=1):
     return _ir_pass.FuseOps(expr, opt_level)
 
 
-def combine_parallel_conv2d(expr):
-    """Fold multiple conv2d into one.
+def combine_parallel_conv2d(expr, min_num_branches=3):
+    """Combine multiple conv2d into one.
 
     Parameters
     ----------
     expr : tvm.relay.Expr
         The input expression.
 
+    min_num_branches : int
+        The minimum number of parallel branches required for the transformation to be applied.
+
     Returns
     -------
     transformed_expr : tvm.relay.Expr
         Transformed expression
     """
-    return _ir_pass.CombineParallelConv2D(expr)
+    return _ir_pass.CombineParallelConv2D(expr, min_num_branches)
 
 
 def alter_op_layout(expr):
index 15eaceb..f86156b 100644 (file)
@@ -1722,6 +1722,64 @@ bool StridedSliceRel(const Array<Type>& types,
 }
 
 
+Array<Array<Layout> > StridedSliceInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>>& old_in_shapes) {
+  CHECK(old_in_layouts.defined());
+  CHECK_EQ(old_in_layouts.size(), 1);
+  CHECK(old_in_shapes.defined());
+  CHECK_EQ(old_in_shapes.size(), 1);
+
+  auto layout = old_in_layouts[0];
+  if (layout.defined() && new_in_layouts.defined()) {
+    CHECK_EQ(new_in_layouts.size(), 1);
+    auto new_layout = new_in_layouts[0];
+    auto shape = old_in_shapes[0];
+
+    // NOTE: Discard "const" qualifier here.
+    auto *params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
+
+    Array<Integer> new_begin, new_end;
+
+    for (size_t i = 0; i < params->begin.size(); i++) {
+      const LayoutAxis& axis = layout[i];
+      if (!axis.IsPrimal()) {
+        // original layout that contains split axes is not supported
+        return {{Layout::Undef()}, {Layout::Undef()}};
+      }
+      auto factor = new_layout.FactorOf(axis);
+      if (factor == -1) {
+        new_begin.push_back(params->begin[i]);
+        new_end.push_back(params->end[i]);
+      } else {
+        if (params->strides.defined() && i < params->strides.size()) {
+          auto stride = params->strides[i];
+          // arbitrary stride is not supported
+          if (stride.defined() && stride->value != 1) {
+            return {{Layout::Undef()}, {Layout::Undef()}};
+          }
+        }
+        int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
+        int64_t end = params->end[i].defined() ? params->end[i]->value :
+            shape[i].as<IntImm>()->value;
+        if (begin % factor || end % factor) {
+          // transform to original layout
+          return {{Layout::Undef()}, {Layout::Undef()}};
+        }
+        new_begin.push_back(tvm::Integer(begin / factor));
+        new_end.push_back(tvm::Integer(end / factor));
+      }
+    }
+    layout = new_layout;
+    params->begin = new_begin;
+    params->end = new_end;
+  }
+  return {{layout}, {layout}};
+}
+
+
 // Positional relay function to create StridedSlice operator used by frontend FFI.
 Expr MakeStridedSlice(Expr data,
                       Array<Integer> begin,
@@ -1783,7 +1841,8 @@ Examples::
 .set_attrs_type_key("relay.attrs.StridedSliceAttrs")
 .add_type_rel("StridedSlice", StridedSliceRel)
 .set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+.set_attr<TOpPattern>("TOpPattern", kInjective)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
 
 
 // relay.split
index cb53698..cd7a852 100644 (file)
@@ -159,10 +159,15 @@ class BranchGroupFinder : private ExprVisitor {
 
 class ParallelConv2DCombiner {
  public:
+  explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) {
+  }
+
   Expr Combine(const Expr& expr) {
     auto groups = BranchGroupFinder().Find(expr);
     for (const Group& group : groups) {
-      if (group.size() < 2) continue;
+      if (group.size() < min_num_branches_) {
+        continue;
+      }
       CombineBranches(group);
     }
     return ExprSubst(expr, std::move(subst_map_));
@@ -170,6 +175,7 @@ class ParallelConv2DCombiner {
 
  private:
   std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
+  uint64_t min_num_branches_;
 
   std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
     int64_t num_filters = 0;  // number of filters of the transformed weight
@@ -343,11 +349,14 @@ class ParallelConv2DCombiner {
   }
 };
 
-Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }
+/*! \brief Combine parallel conv2d if number of branches >= min_num_branches */
+Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
+  return ParallelConv2DCombiner(min_num_branches).Combine(expr);
+}
 
 TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = CombineParallelConv2D(args[0]);
+  *ret = CombineParallelConv2D(args[0], args[1]);
 });
 
 }  // namespace relay
index 0f21288..f7a1c83 100644 (file)
@@ -472,6 +472,48 @@ def test_alter_layout_nchw_upsamping_op():
     assert(alpha_equal(a, b))
 
 
+def test_alter_layout_strided_slice():
+    """Test rewriting strided_slice during alter_iop_layout"""
+    def before():
+        x = relay.var("x", shape=(1, 32, 28, 28))
+        weight = relay.var('weight', shape=(32, 32, 3, 3))
+        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
+        y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    @register_alter_op_layout("nn.conv2d", level=109)
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW4c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 32, 28, 28))
+        weight = relay.var("weight")
+        x = relay.layout_transform(x, "NCHW", "NCHW4c")
+        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
+                            data_layout="NCHW4c")
+        y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    a = before()
+    a = infer_type(a)
+    a = canonicalize_ops(a)
+    a = infer_type(a)
+
+    a = alter_op_layout(a)
+    a = infer_type(a)
+
+    b = expected()
+    b = infer_type(b)
+
+    assert(alpha_equal(a, b))
+
+
 if __name__ == "__main__":
     test_alter_op()
     test_alter_return_none()
@@ -482,3 +524,4 @@ if __name__ == "__main__":
     test_alter_layout_scalar()
     test_alter_layout_concatenate()
     test_alter_layout_nchw_upsamping_op()
+    test_alter_layout_strided_slice()
index 0d6e1e3..3bb656b 100644 (file)
@@ -55,7 +55,7 @@ def test_combine_parallel_conv2d():
 
         y_before = before(x, w1, w2, w3, w4)
         y = relay.ir_pass.infer_type(y_before)
-        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
         y = relay.ir_pass.infer_type(y)
         y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
         y_expected = relay.ir_pass.infer_type(y_expected)
@@ -102,7 +102,7 @@ def test_combine_parallel_conv2d_scale_relu():
         bias = relay.var("bias", shape=(channels2, 1, 1))
         y_before = before(x, w1, w2, scale1, scale2, bias)
         y = relay.ir_pass.infer_type(y_before)
-        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
         y = relay.ir_pass.infer_type(y)
         y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
         y_expected = relay.ir_pass.infer_type(y_expected)
@@ -142,7 +142,7 @@ def test_combine_parallel_conv2d_scale():
         scale2 = relay.var("scale2", shape=(1,))
         y_before = before(x, w1, w2, scale1, scale2)
         y = relay.ir_pass.infer_type(y_before)
-        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
         y = relay.ir_pass.infer_type(y)
         y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
         y_expected = relay.ir_pass.infer_type(y_expected)
@@ -179,7 +179,7 @@ def test_combine_parallel_conv2d_multiple_blocks():
         w = relay.var("w", shape=(out_c, in_c, 1, 1))
         y_before = before(x, w, repeat)
         y = relay.ir_pass.infer_type(y_before)
-        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
         y = relay.ir_pass.infer_type(y)
         y_expected = expected(x, w, out_c, repeat)
         y_expected = relay.ir_pass.infer_type(y_expected)