[CODEGEN] Support cuda tensorcore subbyte int data type in auto tensorcore (#4546)
authorOrion34C <vcitory34@gmail.com>
Fri, 21 Feb 2020 02:43:45 +0000 (10:43 +0800)
committerGitHub <noreply@github.com>
Fri, 21 Feb 2020 02:43:45 +0000 (21:43 -0500)
* support cuda tensorcore subbyte int data type in auto tensorcore

* add lisence

* pass cpplint

* fix code review comments

* merge the int4/int1 codegen tutorial into the existing auto tensorcore tutorial

* using master's new API

* disable tuning when cuda is not enabled

* address cr comment

* do not run the tuning

* fix test failure

* fix cpplint error

* fix bool type reduction bug

* 1. fix a index bug 2. fix returned bytes value of int1/int4/uint4

* fix typo

include/tvm/runtime/data_type.h
include/tvm/tir/expr.h
src/runtime/ndarray.cc
src/target/source/codegen_c.cc
src/target/source/codegen_cuda.cc
src/tir/pass/arg_binder.cc
src/tir/pass/infer_fragment.cc
src/tir/pass/tensor_core.cc
tutorials/optimize/opt_matmul_auto_tensorcore.py

index 7e0ef49..e6f5e55 100644 (file)
@@ -230,7 +230,12 @@ class DataType {
 inline int GetVectorBytes(DataType dtype) {
   int data_bits = dtype.bits() * dtype.lanes();
   // allow bool to exist
-  if (dtype == DataType::Bool()) return 1;
+  if (dtype == DataType::Bool() ||
+      dtype == DataType::Int(4) ||
+      dtype == DataType::UInt(4) ||
+      dtype == DataType::Int(1)) {
+    return 1;
+  }
   CHECK_EQ(data_bits % 8, 0U)
       << "Need to load/store by multiple of bytes";
   return data_bits / 8;
index 756907b..7d49789 100644 (file)
@@ -1261,6 +1261,18 @@ constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
  */
 constexpr const char* tvm_mma_sync = "tvm_mma_sync";
 /*!
+ * \brief tvm intrinsic for tensor core bmma_sync operators.
+ *
+ *  void tvm_bmma_sync(Var fragment_d, Expr index_d,
+ *                     Var fragment_a, Expr index_a,
+ *                     Var fragment_b, Expr index_b,
+ *                     Var fragment_c, Expr index_c) {
+ *    nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a],
+ *                           fragment_b[index_b], fragment_c[index_c]);
+ *  }
+ */
+constexpr const char* tvm_bmma_sync = "tvm_bmma_sync";
+/*!
  * \brief tvm intrinsic for tensor core fill_fragment operators.
  *
  *  void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k,
index f4160cc..91002c9 100644 (file)
@@ -44,7 +44,12 @@ inline void VerifyDataType(DLDataType dtype) {
   } else {
     // allow uint1 as a special flag for bool.
     if (dtype.bits == 1 && dtype.code == kDLUInt) return;
-    CHECK_EQ(dtype.bits % 8, 0);
+    // allow int1/uint4/int4
+    else if (dtype.bits == 1 && dtype.code == kDLInt) return;
+    else if (dtype.bits == 4 && dtype.code == kDLUInt) return;
+    else if (dtype.bits == 4 && dtype.code == kDLInt) return;
+    else
+      CHECK_EQ(dtype.bits % 8, 0);
   }
   CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
 }
index 7f89307..9102055 100644 (file)
@@ -170,8 +170,13 @@ std::string CodeGenC::GetBufferRef(
     } else {
       os << vid;
     }
-    os << '[';
+    os << "[(";
     PrintExpr(index, os);
+    os << ")";
+    if (t.bits() == 4 ||
+        (t.bits() == 1 && t.is_int())) {
+      os << " / " << (32 / t.bits());
+    }
     os << ']';
   } else {
     // Buffer declared as vector type.
@@ -205,8 +210,13 @@ std::string CodeGenC::GetBufferRef(
       PrintType(t.element_of(), os);
       os << "*)";
     }
-    os << vid << " + ";
+    os << vid << " + (";
     PrintExpr(index, os);
+    os << ")";
+    if (t.bits() == 4 ||
+        (t.bits() == 1 && t.is_int())) {
+      os << " / " << (32 / t.bits());
+    }
     os << "))[0]";
   }
   return os.str();
index 889d8b6..d5cab6e 100644 (file)
@@ -144,6 +144,37 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
       }
     }
     switch (t.bits()) {
+      case 1: {
+        if (t.lanes() == 1) {
+          os << "int"; return;
+        } else if (t.lanes() == 8) {
+          os << "int8_t"; return;
+        } else if (t.lanes() == 16) {
+          os << "int16_t"; return;
+        } else if (t.lanes() == 32) {
+          os << "int"; return;
+        } else {
+          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
+        }
+      }
+      case 4: {
+        if (t.lanes() == 1) {
+          os << "int"; return;
+        } else if (t.lanes() == 4) {
+          os << "int16_t"; return;
+        } else if (t.lanes() == 8) {
+          // directly 8 4-bit int in integer.
+          os << "int"; return;
+        } else if (t.lanes() == 16) {
+          os << "int2"; return;
+        } else if (t.lanes() == 32) {
+          os << "int4"; return;
+        } else if (t.lanes() == 64) {
+          os << "int8"; return;
+        } else {
+          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
+        }
+      }
       case 8: {
         if (t.lanes() == 4) {
           // directly 4 8 bit int in integer.
@@ -182,7 +213,6 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
           os << "long"; break;
         }
       }
-      case 1: os << "int"; break;
       default: fail = true; break;
     }
     if (!fail && lanes == 1) {
@@ -371,6 +401,16 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
       this->PrintExpr(op->args[i * 2 + 1], os);
       os << "]" << ((i < 3) ? ", ": ")");
     }
+  } else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 8U);
+    os << "nvcuda::wmma::bmma_sync(";
+    for (int i = 0; i < 4; ++i) {
+      this->PrintExpr(op->args[i * 2], os);
+      os << "[";
+      this->PrintExpr(op->args[i * 2 + 1], os);
+      os << "]" << ((i < 3) ? ", ": ")");
+    }
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
@@ -410,8 +450,12 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
       if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
         CHECK(op->dtype == DataType::Float(16) ||
               op->dtype == DataType::Int(8) ||
-              op->dtype == DataType::UInt(8))
-          << "Matrix_a and matrix_b only support half or char or unsigned char type for now";
+              op->dtype == DataType::UInt(8) ||
+              op->dtype == DataType::Int(4) ||
+              op->dtype == DataType::UInt(4) ||
+              op->dtype == DataType::Int(1))
+          << "Matrix_a and matrix_b only support half or char or unsigned char "
+          << "or uint4 or int4 or int1 type for now";
       } else {
         CHECK(op->dtype == DataType::Float(16) ||
               op->dtype == DataType::Float(32) ||
@@ -425,6 +469,11 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
       stream << ' ';
       PrintType(op->dtype, stream);
     }
+    if ((op->dtype == DataType::Int(4) ||
+         op->dtype == DataType::UInt(4) ||
+         op->dtype == DataType::Int(1)) && scope == "shared") {
+      constant_size = constant_size / (32 / op->dtype.bits());
+    }
     stream << ' '<< vid << '['
            << constant_size << "];\n";
   }
@@ -552,6 +601,24 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
   std::stringstream type;
   PrintType(t, type);
   std::string shape_str = fragment_shapes[variable];
+  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
+    type.str(std::string());
+    if (t.is_int()) {
+      if (t.bits() == 4) {
+        type << "nvcuda::wmma::experimental::precision::s4";
+      } else if (t.bits() == 1) {
+        type << "nvcuda::wmma::experimental::precision::b1";
+      } else {
+        LOG(FATAL) << "Unhandled interger type for wmma fragment!";
+      }
+    } else if (t.is_uint()) {
+      if (t.bits() == 4) {
+        type << "nvcuda::wmma::experimental::precision::u4";
+      } else {
+        LOG(FATAL) << "Unhandled interger type for wmma fragment!";
+      }
+    }
+  }
   if (scope == "wmma.matrix_a") {
     need_mma_h_ = true;
     std::string layout_str = fragment_layouts[variable];
index bd35c76..30542ea 100644 (file)
@@ -184,7 +184,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
                IntImm(DataType::UInt(8), dtype.bits()) &&
                TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
                IntImm(DataType::UInt(16), dtype.lanes()));
-  asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
+  if (!(dtype == DataType::Int(4) ||
+        dtype == DataType::UInt(4) ||
+        dtype == DataType::Int(1))) {
+    asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
+  }
   // data field
   if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
             arg_name + ".data", true)) {
@@ -201,6 +205,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   init_nest_.emplace_back(LetStmtNode::make(
       v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
   for (size_t k = 0; k < buffer->shape.size(); ++k) {
+    if (dtype == DataType::Int(4) ||
+        dtype == DataType::UInt(4) ||
+        dtype == DataType::Int(1)) {
+      break;
+    }
     std::ostringstream field_name;
     field_name << v_shape->name_hint << '[' << k << ']';
     Bind_(buffer->shape[k],
index 0cb1b96..608945a 100644 (file)
@@ -138,7 +138,8 @@ class FragmentChecker : public StmtExprVisitor {
   void VisitExpr_(const CallNode* op) final {
     StmtExprVisitor::VisitExpr_(op);
     // Check shape when calling tvm_mma_sync
-    if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
+    if (op->is_intrinsic(intrinsic::tvm_mma_sync) ||
+        op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
       CHECK_EQ(op->args.size(), 8U);
       const VarNode* buffer_var_d = op->args[0].as<VarNode>();
       const VarNode* buffer_var_a = op->args[2].as<VarNode>();
index 6a5e015..88f7496 100644 (file)
@@ -199,7 +199,11 @@ class MMAMatcher: public StmtVisitor {
     BufferInfo buffer_a;
     if (!check_local_buffer_(load_a, &buffer_a)
         || !(buffer_a.dtype == DataType::Float(16) ||
-             buffer_a.dtype == DataType::Int(8))) {
+             buffer_a.dtype == DataType::Int(8) ||
+             buffer_a.dtype == DataType::UInt(8) ||
+             buffer_a.dtype == DataType::Int(4) ||
+             buffer_a.dtype == DataType::UInt(4) ||
+             buffer_a.dtype == DataType::Int(1))) {
       return false;
     }
 
@@ -208,7 +212,11 @@ class MMAMatcher: public StmtVisitor {
     BufferInfo buffer_b;
     if (!check_local_buffer_(load_b, &buffer_b)
         || !(buffer_b.dtype == DataType::Float(16) ||
-             buffer_b.dtype == DataType::Int(8))) {
+             buffer_b.dtype == DataType::Int(8) ||
+             buffer_b.dtype == DataType::UInt(8) ||
+             buffer_b.dtype == DataType::Int(4) ||
+             buffer_a.dtype == DataType::UInt(4) ||
+             buffer_a.dtype == DataType::Int(1))) {
       return false;
     }
 
@@ -736,6 +744,17 @@ class BufferAnalyser : public StmtExprVisitor {
         warp_tile_.k == 16) {
       return true;
     }
+    if (warp_tile_.m == 8 &&
+        warp_tile_.n == 8 &&
+        warp_tile_.k == 32) {
+      return true;
+    }
+    if (warp_tile_.m == 8 &&
+        warp_tile_.n == 8 &&
+        warp_tile_.k == 128) {
+      return true;
+    }
+
     return false;
   }
 
@@ -869,18 +888,29 @@ class TensorCoreIRMutator : public StmtExprMutator {
       ObjectPtr<BufferNode> buffer_node_c = make_object<BufferNode>();
 
       auto mma_sync_call =
-        [&buffer_node_a, &buffer_node_b]
+        [&buffer_node_a, &buffer_node_b, &ca, &cb]
         (const Buffer &buffer) {
           Buffer buffer_a(buffer_node_a);
           Buffer buffer_b(buffer_node_b);
-          return EvaluateNode::make(
-                  CallNode::make(DataType::Handle(),
-                        intrinsic::tvm_mma_sync,
-                        {buffer->data, buffer->elem_offset,
-                        buffer_a->data, buffer_a->elem_offset,
-                        buffer_b->data, buffer_b->elem_offset,
-                        buffer->data, buffer->elem_offset},
-                        CallNode::Intrinsic));
+          if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
+            return EvaluateNode::make(
+                    CallNode::make(DataType::Handle(),
+                          intrinsic::tvm_bmma_sync,
+                          {buffer->data, buffer->elem_offset,
+                          buffer_a->data, buffer_a->elem_offset,
+                          buffer_b->data, buffer_b->elem_offset,
+                          buffer->data, buffer->elem_offset},
+                          CallNode::Intrinsic));
+          } else {
+            return EvaluateNode::make(
+                    CallNode::make(DataType::Handle(),
+                          intrinsic::tvm_mma_sync,
+                          {buffer->data, buffer->elem_offset,
+                          buffer_a->data, buffer_a->elem_offset,
+                          buffer_b->data, buffer_b->elem_offset,
+                          buffer->data, buffer->elem_offset},
+                          CallNode::Intrinsic));
+          }
         };
 
       auto call_add_c =
index eb6501b..f7cdae2 100644 (file)
@@ -56,6 +56,8 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
       out_type = 'float'
     elif dtype == 'int8':
       out_type = 'int'
+    elif dtype == 'int4' or dtype == 'int1':
+      out_type = 'int'
     if (layout == 'NN'):
       return tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k].astype(out_type) * B[k, j].astype(out_type), axis=k))
     if (layout == 'NT'):
@@ -123,6 +125,12 @@ def test_gemm(N, L, M, dtype, layout):
     if dtype == 'int8':
       factor = 32
       offset = 16
+    elif dtype == 'int4':
+      factor = 64
+      offset = 32
+    elif dtype == 'int1':
+      factor = 256
+      offset = 128
 
     # create cache stages
     AA = s.cache_read(A, "shared", [C])
@@ -139,9 +147,9 @@ def test_gemm(N, L, M, dtype, layout):
     cfg = autotvm.get_config()
 
     cfg.define_knob("bx", [2, 4, 8])
-    cfg.define_knob("by", [16, 32, 64])
-    cfg.define_knob("step_k", [8, 16, 32])
-    cfg.define_knob("v", [4, 8])
+    cfg.define_knob("by", [8, 16, 32, 64])
+    cfg.define_knob("step_k", [1, 2, 4, 8, 16, 32])
+    cfg.define_knob("v", [4, 8, 16, 32])
     by = cfg['by'].val
     bx = cfg['bx'].val
     step_k = cfg['step_k'].val
@@ -150,9 +158,17 @@ def test_gemm(N, L, M, dtype, layout):
     # thread tile
     TX = 8
     TY = 1
+    if dtype == 'int4' or dtype == 'int1':
+      TX = 2
     # warp tile
     warp_tile_m = 16 # it could also be 8 or 32 on CUDA version >= 10.0
-    warp_tile_k = 16 # it must be 16
+    warp_tile_k = 16 # it must be 16 for fp16/int8 data type
+    if dtype == 'int4':
+      warp_tile_m = 8
+      warp_tile_k = 32
+    elif dtype == 'int1':
+      warp_tile_m = 8
+      warp_tile_k = 128
     # block tile
     tile_x = bx * TX
     tile_y = by * TY
@@ -219,6 +235,10 @@ def test_gemm(N, L, M, dtype, layout):
 # and run the kernel to compare with numpy to check whether the results are correct.
 
 # check whether the gpu has tensorcore
+if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+  print("skip because cuda is not enabled..")
+  sys.exit(0)
+
 ctx = tvm.gpu()
 if not nvcc.have_tensorcore(ctx.compute_version):
   print('the gpu has no tensorcore, skipping...')
@@ -234,6 +254,15 @@ if len(sys.argv) >= 5:
 if len(sys.argv) >= 6:
   layout = sys.argv[5]
 
+# check whether current gpu arch support support current dtype's wmma codegen
+cuda_compute_capability = tvm.runtime._ffi_api.GetDeviceAttr(2, 0, 4)
+major, minor= nvcc.parse_compute_version(cuda_compute_capability)
+if dtype == 'int8':
+  assert(major == 7 and minor >= 2)
+elif dtype == 'int4' or dtype == 'int1':
+  # int4/int1 only support layout TN
+  assert(major == 7 and minor == 5 and layout == 'TN')
+
 def tune_and_evaluate(M, N, L, dtype, layout):
   task = autotvm.task.create(test_gemm, args=(N, L, M, dtype, layout), target='cuda')
   print(task.config_space)
@@ -305,6 +334,42 @@ def tune_and_evaluate(M, N, L, dtype, layout):
       c_np = np.dot(a_np.astype(np.int32), b_np.astype(np.int32).T)
     elif (layout == "TT"):
       c_np = np.dot(a_np.astype(np.int32).T, b_np.astype(np.int32).T)
+  elif dtype == 'int4':
+    c_np_type = np.int32
+    a_np_int = np.random.randint(low=-8, high=7, size=shape_a).astype(np.int32)
+    b_np_int = np.random.randint(low=-8, high=7, size=shape_b).astype(np.int32)
+    # "TN"
+    c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
+    a_np = np.zeros(shape=(N, int(L/8)), dtype = np.int32)
+    b_np = np.zeros(shape=(M, int(L/8)), dtype = np.int32)
+    # a_np --> col_major
+    for i in range(N):
+      for j in range(int(L/8)):
+        for k in range(8):
+          a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 8 + k] & 0xf) << ((7 - k) * 4))
+
+    # b_np --> row_major
+    for i in range(M):
+      for j in range(int(L/8)):
+        for k in range(8):
+          b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 8 + k] & 0xf) << ((7 - k) * 4))
+  elif dtype == 'int1':
+    c_np_type = np.int32
+    a_np_int = np.random.randint(low=0, high=1, size=shape_a).astype(np.int32)
+    b_np_int = np.random.randint(low=0, high=1, size=shape_b).astype(np.int32)
+    # "TN"
+    c_np = np.dot(a_np_int.astype(np.int32), b_np_int.astype(np.int32).T)
+    a_np = np.zeros(shape=(N, int(L/32)), dtype = np.int32)
+    b_np = np.zeros(shape=(M, int(L/32)), dtype = np.int32)
+    for i in range(N):
+      for j in range(int(L/32)):
+        for k in range(32):
+          a_np[i, j] = a_np[i, j] | ((a_np_int[i, j * 32 + k] & 0xf) << (31 - k))
+
+    for i in range(M):
+      for j in range(int(L/32)):
+        for k in range(32):
+          b_np[i, j] = b_np[i, j] | ((b_np_int[i, j * 32 + k] & 0xf) << (31 - k))
 
   c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
   a_tvm = tvm.nd.array(a_np, ctx=ctx)