[TIR][BUILD] Remove buffer params from pass config. (#5652)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sat, 23 May 2020 15:38:00 +0000 (08:38 -0700)
committerGitHub <noreply@github.com>
Sat, 23 May 2020 15:38:00 +0000 (08:38 -0700)
Buffer configurations can be passed during construction
and does not need to be part of the build config.

This is a refactor step to simplify the BuildConfig for the PassContext migration.

13 files changed:
include/tvm/target/target.h
python/tvm/driver/build_module.py
python/tvm/target/build_config.py
python/tvm/te/tensor_intrin.py
src/driver/driver_api.cc
src/target/target.cc
src/tir/ir/buffer.cc
tests/python/unittest/test_te_schedule.py
tests/python/unittest/test_te_schedule_ops.py
tests/python/unittest/test_te_schedule_tensorize.py
tests/python/unittest/test_te_tensor.py
topi/python/topi/cuda/tensor_intrin.py
topi/python/topi/x86/tensor_intrin.py

index c28b051..de48ac2 100644 (file)
@@ -178,17 +178,6 @@ TVM_DLL Target hexagon(const std::vector<std::string>& options = std::vector<std
 class BuildConfigNode : public Object {
  public:
   /*!
-   * \brief The data alignment to use when constructing buffers. If this is set to
-   * -1, then TVM's internal default will be used
-   */
-  int data_alignment = -1;
-  /*!
-   * \brief The offset factor to use when constructing buffers. If this is set to
-   * 0, then the offset field is not used.
-   */
-  int offset_factor = 0;
-
-  /*!
    * \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
    * done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
    */
@@ -217,9 +206,6 @@ class BuildConfigNode : public Object {
   /*! \brief List of passes to be injected into the low-level pipeline. */
   std::vector<std::pair<int, transform::Pass>> add_lower_pass;
 
-  /*! \brief Whether to dump the IR of each pass (only when building from python) */
-  bool dump_pass_ir = false;
-
   /*! \brief Whether to instrument loads and stores with check for out of the bounds. */
   bool instrument_bound_checkers = false;
 
@@ -233,8 +219,6 @@ class BuildConfigNode : public Object {
   bool disable_assert = false;
 
   void VisitAttrs(AttrVisitor* v) {
-    v->Visit("data_alignment", &data_alignment);
-    v->Visit("offset_factor", &offset_factor);
     v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
     v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
     v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
@@ -243,7 +227,6 @@ class BuildConfigNode : public Object {
     v->Visit("restricted_func", &restricted_func);
     v->Visit("detect_global_barrier", &detect_global_barrier);
     v->Visit("partition_const_loop", &partition_const_loop);
-    v->Visit("dump_pass_ir", &dump_pass_ir);
     v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
     v->Visit("disable_select_rewriting", &disable_select_rewriting);
     v->Visit("disable_vectorize", &disable_vectorize);
index 216cad9..97ed8d8 100644 (file)
@@ -56,7 +56,6 @@ def get_binds(args, compact=False, binds=None):
         The list of symbolic buffers of arguments.
     """
     binds = {} if binds is None else binds.copy()
-    cfg = BuildConfig.current()
     arg_list = []
     for x in args:
         if isinstance(x, tensor.Tensor):
@@ -66,9 +65,6 @@ def get_binds(args, compact=False, binds=None):
                 buf = tvm.tir.decl_buffer(
                     x.shape,
                     dtype=x.dtype,
-                    name=x.name,
-                    data_alignment=cfg.data_alignment,
-                    offset_factor=cfg.offset_factor,
                     buffer_type=buffer_type)
                 binds[x] = buf
                 arg_list.append(buf)
@@ -157,8 +153,6 @@ def lower(sch,
     """
     cfg = BuildConfig.current()
     add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
-    if cfg.dump_pass_ir:
-        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
     lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
     lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
     lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
index 538ee7d..a99797a 100644 (file)
@@ -45,11 +45,8 @@ class BuildConfig(Object):
         "unroll_explicit": True,
         "detect_global_barrier": False,
         "partition_const_loop": False,
-        "offset_factor": 0,
-        "data_alignment": -1,
         "restricted_func": True,
         "double_buffer_split_loop": 1,
-        "dump_pass_ir": False,
         "instrument_bound_checkers": False,
         "disable_select_rewriting": False,
         "disable_vectorize": False,
@@ -129,14 +126,6 @@ def build_config(**kwargs):
     partition_const_loop: bool, default=False
         Whether partition const loop
 
-    data_alignment: int, optional
-        The alignment of data pointer in bytes.
-        If -1 is passed, the alignment will be set to TVM's internal default.
-
-    offset_factor: int, default=0
-        The factor used in default buffer declaration.
-        If specified as 0, offset field is not used.
-
     restricted_func: bool, default=True
         Whether build restricted function.
         That is each buffer argument to the function are guaranteed
@@ -152,8 +141,6 @@ def build_config(**kwargs):
         phase contains an integer on which optimization pass we apply the pass.
         Additional lowering passes to be applied before make_api.
 
-    dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False
-
     Returns
     -------
     config: BuildConfig
index c5c2afe..cd488a7 100644 (file)
@@ -20,7 +20,6 @@ import tvm.tir
 
 from tvm.runtime import Object, convert
 from tvm.ir import Range
-from tvm.target import BuildConfig
 from .tensor import PlaceholderOp
 
 from . import tensor as _tensor
@@ -68,7 +67,9 @@ class TensorIntrin(Object):
 def decl_tensor_intrin(op,
                        fcompute,
                        name="tensor_intrin",
-                       binds=None, scalar_params=None):
+                       binds=None,
+                       scalar_params=None,
+                       default_buffer_params=None):
     """Declare a tensor intrinsic function.
 
     Parameters
@@ -104,6 +105,9 @@ def decl_tensor_intrin(op,
     scalar_params: a list of variables used by op, whose values will be passed
                    as scalar_inputs when the tensor intrinsic is called.
 
+    default_buffer_params: Optional[dict]
+        Dictionary of buffer arguments to be passed when constructing a buffer.
+
     Returns
     -------
     intrin: TensorIntrin
@@ -122,12 +126,11 @@ def decl_tensor_intrin(op,
         if not isinstance(t.op, PlaceholderOp):
             raise ValueError("Do not yet support composition op")
 
-    cfg = BuildConfig.current()
+    default_buffer_params = {} if default_buffer_params is None else default_buffer_params
     for t in tensors:
         buf = (binds[t] if t in binds else
                tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
-                                   data_alignment=cfg.data_alignment,
-                                   offset_factor=cfg.offset_factor))
+                                   **default_buffer_params))
         binds_list.append(buf)
 
     if scalar_params:
index cdd9d54..ca1e122 100644 (file)
@@ -91,8 +91,7 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
 
   for (const auto& x : args) {
     if (out_binds->find(x) == out_binds->end()) {
-      auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, config->data_alignment,
-                                           config->offset_factor, compact);
+      auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, -1, 0, compact);
       out_binds->Set(x, buf);
       out_arg_list->push_back(buf);
     } else {
index 644ebdf..aac5a2b 100644 (file)
@@ -357,8 +357,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<BuildConfigNode>([](const ObjectRef& node, ReprPrinter* p) {
       auto* op = static_cast<const BuildConfigNode*>(node.get());
       p->stream << "build_config(";
-      p->stream << "data_alignment=" << op->data_alignment << ", ";
-      p->stream << "offset_factor=" << op->offset_factor << ", ";
       p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", ";
       p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", ";
       p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", ";
@@ -367,7 +365,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "restricted_func=" << op->restricted_func << ", ";
       p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
       p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
-      p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
       p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
       p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
       p->stream << "disable_vectorize=" << op->disable_vectorize;
index 45b9680..8b98ed9 100644 (file)
@@ -34,7 +34,7 @@
 
 namespace tvm {
 namespace tir {
-// TODO(tqchen): change to floormod/div
+
 using IndexMod = tir::FloorModNode;
 using IndexDiv = tir::FloorDivNode;
 
index 9b8d406..2c851cc 100644 (file)
@@ -115,7 +115,6 @@ def test_fuse_with_split():
     assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations)
     assert tuple(s[T].leaf_iter_vars) == (xo, fused)
 
-@pytest.mark.xfail
 def test_fuse_with_out_of_order_axis():
     m = te.size_var('m')
     n = te.size_var('n')
@@ -125,9 +124,10 @@ def test_fuse_with_out_of_order_axis():
     s = te.create_schedule(T.op)
     y = T.op.axis[1]
     xo, xi = s[T].split(T.op.axis[0], factor=10)
-    fused = s[T].fuse(xo, y) # should throw here
 
-@pytest.mark.xfail
+    with pytest.raises(RuntimeError):
+            fused = s[T].fuse(xo, y) # should throw here
+
 def test_fuse_with_out_of_order_axis_with_reorder():
     m = te.size_var('m')
     n = te.size_var('n')
@@ -144,23 +144,21 @@ def test_fuse_with_out_of_order_axis_with_reorder():
     y = T.op.axis[1]
     xo, xi = s[T].split(T.op.axis[0], factor=10)
     s[T].reorder(y, xo, xi)
-    fused = s[T].fuse(y, xi) # should throw here
+
+    with pytest.raises(RuntimeError):
+        fused = s[T].fuse(y, xi) # should throw here
 
 def test_singleton():
-    print("test singleton")
     A = te.placeholder((), name='A')
     T = te.compute((), lambda : A() + 1)
     s = te.create_schedule(T.op)
-    print("test singleton fin1")
     fused = s[T].fuse()
     assert any(isinstance(x, tvm.te.schedule.Singleton) for x in s[T].relations)
     assert tuple(s[T].leaf_iter_vars) == (fused,)
     dump = pkl.dumps(s)
-    print("test singleton fin3")
     s_loaded = pkl.loads(dump)
-    print("test singleton fin2")
     assert isinstance(s_loaded, tvm.te.schedule.Schedule)
-    print("test singleton fin")
+
 
 def test_vectorize():
     m = te.size_var('m')
@@ -177,13 +175,14 @@ def test_vectorize():
     assert s[T].iter_var_attrs[xi].iter_type == UNROLL
     assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE
 
-@pytest.mark.xfail
+
 def test_vectorize_commreduce():
     V = te.placeholder((128,), name='V')
     ax = te.reduce_axis((0, 128), name='ax')
     O = te.compute((1,), lambda _: te.sum(V[ax], axis=[ax]))
     s = te.create_schedule(O.op)
-    s[O].vectorize(ax) # should throw here
+    with pytest.raises(RuntimeError):
+        s[O].vectorize(ax) # should throw here
 
 def test_pragma():
     m = 100
@@ -271,8 +270,9 @@ def test_tensor_intrin_scalar_params():
         assert(sp[1] == w)
         return tvm.tir.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])
 
-    with tvm.target.build_config(offset_factor=1):
-      intrin = te.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
+    intrin = te.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w], default_buffer_params={
+        "offset_factor": 1
+    })
     assert intrin.op == z.op
     assert intrin.reduce_init is None
     assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
index 7cbf20e..3f93c77 100644 (file)
@@ -321,10 +321,9 @@ def intrin_gemv(m, n):
             "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
         return body, reset, update
 
-    with tvm.target.build_config(data_alignment=16,
-                          offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func,
-                                      binds={w: Wb})
+    buffer_params = {"data_alignment": 16, "offset_factor": 16}
+    return te.decl_tensor_intrin(
+        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)
 
 
 def test_schedule_tensor_compute1():
@@ -377,8 +376,9 @@ def intrin_vadd(n, cache_read=False, cache_write=False):
         ib.emit(tvm.tir.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
         return ib.get()
 
-    with tvm.target.build_config(offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func, binds=binds)
+    return te.decl_tensor_intrin(z.op, intrin_func, binds=binds, default_buffer_params={
+        "offset_factor": 16
+    })
 
 
 def test_schedule_tensor_compute2():
index ef5b3fd..5152235 100644 (file)
@@ -25,8 +25,8 @@ def intrin_vadd(n):
         xx, yy = ins
         zz = outs[0]
         return tvm.tir.call_packed("vadd", xx, yy, zz)
-    with tvm.target.build_config(offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func)
+    buffer_params = {"offset_factor": 16}
+    return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params=buffer_params)
 
 def intrin_gemv(m, n):
     w = te.placeholder((m, n), name='w')
@@ -52,10 +52,9 @@ def intrin_gemv(m, n):
             "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
         return body, reset, update
 
-    with tvm.target.build_config(data_alignment=16,
-                          offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func,
-                                      binds={w: Wb})
+    buffer_params = {"offset_factor": 16, "data_alignment": 16}
+    return te.decl_tensor_intrin(
+        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)
 
 def intrin_gemv_no_reset(m, n):
     w = te.placeholder((m, n), name='w')
@@ -79,10 +78,10 @@ def intrin_gemv_no_reset(m, n):
             "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
         return body, None, update
 
-    with tvm.target.build_config(data_alignment=16,
-                          offset_factor=16):
-        return te.decl_tensor_intrin(z.op, intrin_func,
-                                      binds={w: Wb})
+
+    buffer_params = {"offset_factor": 16, "data_alignment": 16}
+    return te.decl_tensor_intrin(
+        z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params)
 
 
 def test_tensorize_vadd():
@@ -248,8 +247,9 @@ def test_tensorize_op():
             zz = outs[0]
             return tvm.tir.call_packed("op", xx, zz)
 
-        with tvm.target.build_config(offset_factor=2):
-            return te.decl_tensor_intrin(y.op, intrin_func)
+        return te.decl_tensor_intrin(y.op, intrin_func, default_buffer_params={
+            "offset_factor": 2
+        })
 
     A = te.placeholder((5, 5), name='A')
     B = te.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
@@ -286,8 +286,7 @@ def test_tensorize_tensor_compute_op():
         def intrin_func(ins, outs):
             return tvm.tir.call_packed("multivadd")
 
-        with tvm.target.build_config():
-            return te.decl_tensor_intrin(z.op, intrin_func, name="multivadd")
+        return te.decl_tensor_intrin(z.op, intrin_func, name="multivadd")
 
     def intrin_vadd(n):
         dtype = 'float32'
@@ -297,9 +296,7 @@ def test_tensorize_tensor_compute_op():
         s = te.create_schedule(z.op)
 
         def create_buffer(t):
-            return tvm.tir.decl_buffer(t.shape, t.dtype,
-                                   name='W'+t.name,
-                                   offset_factor=16)
+            return tvm.tir.decl_buffer(t.shape, t.dtype, name='W'+t.name, offset_factor=16)
 
         def intrin_func(ins, outs):
             ib = tvm.tir.ir_builder.create()
@@ -307,11 +304,9 @@ def test_tensorize_tensor_compute_op():
                                     ins[0].access_ptr("r"), ins[1].access_ptr('r'),
                                     outs[0].access_ptr('wr')))
             return ib.get()
-
-        with tvm.target.build_config(offset_factor=16):
-            return te.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x),
-                                                                    y: create_buffer(y),
-                                                                    z: create_buffer(z)})
+        return te.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x),
+                                                                y: create_buffer(y),
+                                                                z: create_buffer(z)})
 
     # cache_read, cache_write
     M = 1024
index 5d3cbad..a8ab3cf 100644 (file)
@@ -117,8 +117,9 @@ def test_tensor_compute1():
             ib.emit(tvm.tir.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
             return ib.get()
 
-        with tvm.target.build_config(offset_factor=n):
-            return te.decl_tensor_intrin(z.op, intrin_func)
+        return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={
+            "offset_factor": n
+        })
 
     vadd = intrin_vadd(factor)
 
@@ -159,8 +160,8 @@ def test_tensor_compute2():
                 "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
             return body, reset, update
 
-        with tvm.target.build_config(offset_factor=n):
-            return te.decl_tensor_intrin(z.op, intrin_func)
+        return te.decl_tensor_intrin(z.op, intrin_func,
+                                     default_buffer_params={"offset_factor": n})
 
     vgemm = intrin_gemm(factor1, factor2, factor)
 
@@ -290,8 +291,8 @@ def test_tensor_pool():
             dout = outs[0]
             return tvm.tir.call_packed("op", dinp, dout)
 
-        with tvm.target.build_config(offset_factor=1):
-            return te.decl_tensor_intrin(P.op, intrin_func)
+        return te.decl_tensor_intrin(P.op, intrin_func,
+                                     default_buffer_params={"offset_factor": 1})
 
     A = te.placeholder((1, 64, 16, 16), name='A')
     P = pool(data=A, kernel=(3, 3), stride=(1, 1), padding=(0, 0, 0, 0),
index f8fce34..3941c00 100644 (file)
@@ -69,14 +69,15 @@ def dp4a(x_scope='local', y_scope='local', z_scope='local'):
 
         return _instr(0), _instr(1), _instr(2) # body, reset, update
 
-    with tvm.target.build_config(data_alignment=4, offset_factor=1) as cfg:
-        scopes = {x: x_scope, y: y_scope, z: z_scope}
-        binds = {t: tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
-                                        data_alignment=cfg.data_alignment,
-                                        offset_factor=cfg.offset_factor,
-                                        scope=scopes[t]) for t in [x, y, z]}
-
-        return te.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
+    default_buffer_params = {
+        "data_alignment": 4, "offset_factor": 1
+    }
+    scopes = {x: x_scope, y: y_scope, z: z_scope}
+    binds = {t: tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name,
+                                    scope=scopes[t], **default_buffer_params) for t in [x, y, z]}
+
+    return te.decl_tensor_intrin(
+        z.op, _intrin_func, binds=binds, default_buffer_params=default_buffer_params)
 
 
 def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype):
index 955b6b4..ee8d83d 100644 (file)
@@ -110,8 +110,10 @@ def dot_16x1x16_uint8_int8_int32_skylake():
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
 
-    with tvm.target.build_config(offset_factor=1, partition_const_loop=True):
-        return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
+    buffer_params = {"offset_factor" : 1}
+    return te.decl_tensor_intrin(
+        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
+        default_buffer_params=buffer_params)
 
 
 def dot_16x1x16_uint8_int8_int16():
@@ -191,9 +193,10 @@ def dot_16x1x16_uint8_int8_int16():
 
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
-
-    with tvm.target.build_config(offset_factor=1, partition_const_loop=True):
-        return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
+    buffer_params = {"offset_factor" : 1}
+    return te.decl_tensor_intrin(
+        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
+        default_buffer_params=buffer_params)
 
 
 def dot_16x1x16_uint8_int8_int32_cascadelake():
@@ -287,5 +290,7 @@ def dot_16x1x16_uint8_int8_int32_cascadelake():
         # body, reset, update
         return _instr(0), _instr(1), _instr(2)
 
-    with tvm.target.build_config(offset_factor=1, partition_const_loop=True):
-        return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
+    buffer_params = {"offset_factor" : 1}
+    return te.decl_tensor_intrin(
+        C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
+        default_buffer_params=buffer_params)