[TOPI] Improve CUDA softmax scheduling (#5600)
authorWei Pan <60017475+wpan11nv@users.noreply.github.com>
Mon, 25 May 2020 16:44:57 +0000 (09:44 -0700)
committerGitHub <noreply@github.com>
Mon, 25 May 2020 16:44:57 +0000 (09:44 -0700)
- Do not use multiple kernels

- Schedule with warp reductions

- Fixed a bug on the lower warp memory pass

- Fixed warp shuffle intrinsics for the nvptx backend.

Signed-off-by: Wei Pan <weip@nvidia.com>
src/target/llvm/codegen_llvm.cc
src/target/llvm/llvm_common.h
src/tir/transforms/lower_warp_memory.cc
tests/python/unittest/test_tir_transform_lower_warp_memory.py
topi/python/topi/cuda/softmax.py

index b43e988..5c2c41a 100644 (file)
@@ -736,7 +736,40 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type
 #endif  // TVM_LLVM_VERSION
 }
 
+// Check if this is a warp shuffle intrinsic call and match its
+// corresponding nvvm intrinsic. Return true if the match is successful.
+static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
+  // Only 32 bit data type is supported.
+  if (op->dtype.is_vector() || op->dtype.bits() != 32) {
+    return false;
+  }
+
+  // Intrinsic lookup table.
+  // It is difficult to emit _sync verion that works on Pascal.
+  // We ignore the mask and only emit the non-sync version for nvptx.
+  llvm::Intrinsic::ID ids[] = {
+      llvm::Intrinsic::nvvm_shfl_idx_i32,  llvm::Intrinsic::nvvm_shfl_idx_f32,
+      llvm::Intrinsic::nvvm_shfl_up_i32,   llvm::Intrinsic::nvvm_shfl_up_f32,
+      llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};
+
+  int offset = 0;
+  if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
+    offset = 0;
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
+    offset = 2;
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
+    offset = 4;
+  } else {
+    return false;
+  }
+
+  *id = ids[offset + op->dtype.is_float()];
+  return true;
+}
+
 llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
+  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
+
   if (op->is_intrinsic("llvm_intrin")) {
     CHECK_GE(op->args.size(), 2U);
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
@@ -781,6 +814,25 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
     }
   } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
     return CreateStorageSync(op);
+  } else if (GetWarpShuffleIntrinsic(op, &id)) {
+    std::vector<llvm::Value*> arg_value;
+    std::vector<llvm::Type*> arg_type;
+    // Ignore the first mask operand and remove the last
+    // redundant warp_size..
+    size_t n_args = op->args.size() - 1;
+    for (size_t i = 1; i < n_args; ++i) {
+      arg_value.push_back(MakeValue(op->args[i]));
+      arg_type.push_back(arg_value.back()->getType());
+    }
+    llvm::Type* return_type = arg_type[0];
+    llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
+    return builder_->CreateCall(func, arg_value);
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
+    // Only nvptx target may keep this intrinsic at this point.
+    // PTX assembly: asm "activemask.b32 r1;"
+    auto fty = llvm::FunctionType::get(t_int32_, false);
+    auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
+    return builder_->CreateCall(val);
   } else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
     const LoadNode* l = op->args[0].as<LoadNode>();
     CHECK(op->args.size() == 1 && l);
index 49389fe..529ee74 100644 (file)
@@ -28,6 +28,7 @@
 #include <llvm/Analysis/TargetTransformInfo.h>
 #include <llvm/Bitcode/BitcodeWriter.h>
 #include <llvm/ExecutionEngine/MCJIT.h>
+#include <llvm/IR/InlineAsm.h>
 #include <llvm/IR/Intrinsics.h>
 #include <llvm/IR/Value.h>
 #include <llvm/Support/SourceMgr.h>
index 4c8dec0..91879b6 100644 (file)
@@ -213,9 +213,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
     alloc_size *= op->dtype.lanes();
     std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
     warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
-    CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
-        << "Warp memory must be multiple of the extent of threadIdx.x";
-    warp_group_ = alloc_size / (width_ * warp_coeff_);
+
+    // Align the local memory size. The number of elements may not
+    // be a multiple of width_ * warp_coeff_; round it up.
+    int factor = width_ * warp_coeff_;
+    warp_group_ = (alloc_size + (factor - 1)) / factor;
+    alloc_size = warp_group_ * factor;
+
     return AllocateNode::make(op->buffer_var, op->dtype,
                               {make_const(DataType::Int(32), alloc_size / width_)}, op->condition,
                               this->VisitStmt(op->body));
index c3cf289..ce9dd56 100644 (file)
@@ -218,9 +218,45 @@ def test_lower_warp_memory_cuda_2_buffers():
     check_cuda("float32")
     check_cuda("float16")
 
+def test_lower_warp_memory_roundup():
+    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    def check(m):
+        A = te.placeholder((m,), name='A')
+        B = te.compute((m,), lambda i: A[i] + 1, name='B')
+
+        with tvm.target.create("cuda"):
+            s = te.create_schedule(B.op)
+            xo, xi = s[B].split(B.op.axis[0], factor=32)
+            tx = te.thread_axis("threadIdx.x")
+            s[B].bind(xo, te.thread_axis("blockIdx.x"))
+            s[B].bind(xi, tx)
+
+            AA = s.cache_read(A, "warp", [B])
+            _, yi = s[AA].split(s[AA].op.axis[0], factor=32)
+            s[AA].bind(yi, tx)
+            s[AA].compute_at(s[B], xo)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B], "cuda")
+            A_np = np.random.uniform(size=(m,)).astype(A.dtype)
+            B_np = np.zeros(shape=(m,)).astype(B.dtype)
+            A_nd = tvm.nd.array(A_np, ctx)
+            B_nd = tvm.nd.array(B_np, ctx)
+            func(A_nd, B_nd)
+            B_np = A_np + 1
+            tvm.testing.assert_allclose(B_nd.asnumpy(), B_np)
+
+    check(m=31)
+    check(m=32)
+    check(m=33)
+
 if __name__ == "__main__":
     test_lower_warp_memory_local_scope()
     test_lower_warp_memory_correct_indices()
     test_lower_warp_memory_cuda_end_to_end()
     test_lower_warp_memory_cuda_half_a_warp()
     test_lower_warp_memory_cuda_2_buffers()
+    test_lower_warp_memory_roundup()
index 62c437a..6142f48 100644 (file)
 # under the License.
 # pylint: disable=invalid-name, unused-variable, trailing-whitespace
 """Schedule for softmax operator"""
+from tvm import target as target_
 from tvm import te
 from tvm.contrib import cudnn
 from .. import generic
 from .injective import schedule_injective_from_existing
 
-
 def schedule_softmax(outs):
     """Schedule for softmax op.
 
@@ -39,6 +39,7 @@ def schedule_softmax(outs):
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
     softmax = outs[0]
+    tgt = target_.Target.current(allow_none=False)
 
     op_tag = softmax.op.tag
     if op_tag == 'softmax_output':
@@ -53,6 +54,14 @@ def schedule_softmax(outs):
         raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
                          Got {0}'.format(op_tag))
 
+    # The nvptx backend only supports 32-bits warp shuffle instructions.
+    #
+    # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
+    def sched_warp_softmax():
+        if tgt.target_name == "nvptx":
+            return softmax.dtype == "float32" or softmax.dtype == "int32"
+        return True
+
     if len(softmax.shape) > 2:
         ops = [max_elem.op, expsum.op, softmax.op]
         if exp is not None:
@@ -60,6 +69,46 @@ def schedule_softmax(outs):
 
         for op in ops:
             s = schedule_injective_from_existing(s, op.output(0))
+
+    elif sched_warp_softmax():
+        # A warp of 32 threads performs a row reduction.
+        num_thread = tgt.thread_warp_size
+        block_x = te.thread_axis("blockIdx.x")
+        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
+
+        # (4) softmax
+        xo, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
+        _, xii = s[softmax].split(xi, factor=4)
+        s[softmax].vectorize(xii)
+        s[softmax].bind(xo, thread_x)
+        s[softmax].bind(softmax.op.axis[0], block_x)
+
+        # (3) expsum
+        k = expsum.op.reduce_axis[0]
+        ko, _ = s[expsum].split(k, nparts=num_thread)
+        s[expsum].bind(ko, thread_x)
+        s[expsum].compute_at(s[softmax], xo)
+
+        # (2) exp
+        if exp is not None:
+            xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
+            _, xii = s[exp].split(xi, factor=4)
+            s[exp].vectorize(xii)
+            s[exp].bind(xo, thread_x)
+            s[exp].compute_at(s[expsum], expsum.op.axis[0])
+            s[exp].compute_at(s[softmax], softmax.op.axis[0])
+            s[exp].set_scope("warp")
+
+        # (1) max_elem
+        k = max_elem.op.reduce_axis[0]
+        ko, _ = s[max_elem].split(k, nparts=num_thread)
+        s[max_elem].bind(ko, thread_x)
+        if exp is not None:
+            s[max_elem].compute_at(s[exp], xo)
+        else:
+            s[max_elem].bind(ko, thread_x)
+            s[max_elem].bind(max_elem.op.axis[0], block_x)
+
     else:
         num_thread = 64
         block_x = te.thread_axis("blockIdx.x")