#endif // TVM_LLVM_VERSION
}
+// Check if this is a warp shuffle intrinsic call and match its
+// corresponding nvvm intrinsic. Return true if the match is successful.
+static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
+ // Only 32-bit data types are supported.
+ if (op->dtype.is_vector() || op->dtype.bits() != 32) {
+ return false;
+ }
+
+ // Intrinsic lookup table.
+ // It is difficult to emit the _sync version that works on Pascal.
+ // We ignore the mask and only emit the non-sync version for nvptx.
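+ // Integer and float variants sit side by side in the table; the dtype's
+ // float flag selects between them below.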
+ llvm::Intrinsic::ID ids[] = {
+ llvm::Intrinsic::nvvm_shfl_idx_i32, llvm::Intrinsic::nvvm_shfl_idx_f32,
+ llvm::Intrinsic::nvvm_shfl_up_i32, llvm::Intrinsic::nvvm_shfl_up_f32,
+ llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};
+
+ int offset = 0;
+ if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
+ offset = 0;
+ } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
+ offset = 2;
+ } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
+ offset = 4;
+ } else {
+ return false;
+ }
+
+ *id = ids[offset + op->dtype.is_float()];
+ return true;
+}
+
llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
+ llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
+
if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
}
} else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) {
return CreateStorageSync(op);
+ } else if (GetWarpShuffleIntrinsic(op, &id)) {
+ std::vector<llvm::Value*> arg_value;
+ std::vector<llvm::Type*> arg_type;
+ // Ignore the first mask operand and drop the last redundant warp_size argument.
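+ // The TIR call is expected to carry (mask, value, lane_or_delta, width, warp_size),
+ // so only the middle operands are forwarded to the nvvm intrinsic.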
+ size_t n_args = op->args.size() - 1;
+ for (size_t i = 1; i < n_args; ++i) {
+ arg_value.push_back(MakeValue(op->args[i]));
+ arg_type.push_back(arg_value.back()->getType());
+ }
+ llvm::Type* return_type = arg_type[0];
+ llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
+ return builder_->CreateCall(func, arg_value);
+ } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
+ // Only the nvptx target may keep this intrinsic at this point.
+ // PTX assembly: asm "activemask.b32 r1;"
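+ // activemask.b32 returns a 32-bit mask with one bit set per active lane in the warp.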
+ auto fty = llvm::FunctionType::get(t_int32_, false);
+ auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
+ return builder_->CreateCall(val);
} else if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const LoadNode* l = op->args[0].as<LoadNode>();
CHECK(op->args.size() == 1 && l);
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/ExecutionEngine/MCJIT.h>
+#include <llvm/IR/InlineAsm.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/Value.h>
#include <llvm/Support/SourceMgr.h>
alloc_size *= op->dtype.lanes();
std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);
- CHECK_EQ(alloc_size % (width_ * warp_coeff_), 0)
- << "Warp memory must be multiple of the extent of threadIdx.x";
- warp_group_ = alloc_size / (width_ * warp_coeff_);
+
+ // Align the local memory size. The number of elements may not
+ // be a multiple of width_ * warp_coeff_; round it up.
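+ // e.g. 24 elements with width_ * warp_coeff_ == 32 round up to 32 (warp_group_ == 1).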
+ int factor = width_ * warp_coeff_;
+ warp_group_ = (alloc_size + (factor - 1)) / factor;
+ alloc_size = warp_group_ * factor;
+
return AllocateNode::make(op->buffer_var, op->dtype,
{make_const(DataType::Int(32), alloc_size / width_)}, op->condition,
this->VisitStmt(op->body));
check_cuda("float32")
check_cuda("float16")
+def test_lower_warp_memory_roundup():
+ if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+ print("skip because cuda is not enabled..")
+ return
+
+ def check(m):
+ A = te.placeholder((m,), name='A')
+ B = te.compute((m,), lambda i: A[i] + 1, name='B')
+
+ with tvm.target.create("cuda"):
+ s = te.create_schedule(B.op)
+ xo, xi = s[B].split(B.op.axis[0], factor=32)
+ tx = te.thread_axis("threadIdx.x")
+ s[B].bind(xo, te.thread_axis("blockIdx.x"))
+ s[B].bind(xi, tx)
+
+ AA = s.cache_read(A, "warp", [B])
+ _, yi = s[AA].split(s[AA].op.axis[0], factor=32)
+ s[AA].bind(yi, tx)
+ s[AA].compute_at(s[B], xo)
+
+ ctx = tvm.gpu(0)
+ func = tvm.build(s, [A, B], "cuda")
+ A_np = np.random.uniform(size=(m,)).astype(A.dtype)
+ B_np = np.zeros(shape=(m,)).astype(B.dtype)
+ A_nd = tvm.nd.array(A_np, ctx)
+ B_nd = tvm.nd.array(B_np, ctx)
+ func(A_nd, B_nd)
+ B_np = A_np + 1
+ tvm.testing.assert_allclose(B_nd.asnumpy(), B_np)
+
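+ # Sizes just below, at, and above the 32-wide warp cover both the exact-fit
+ # and the rounded-up warp memory allocations.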
+ check(m=31)
+ check(m=32)
+ check(m=33)
+
if __name__ == "__main__":
test_lower_warp_memory_local_scope()
test_lower_warp_memory_correct_indices()
test_lower_warp_memory_cuda_end_to_end()
test_lower_warp_memory_cuda_half_a_warp()
test_lower_warp_memory_cuda_2_buffers()
+ test_lower_warp_memory_roundup()
# under the License.
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
+from tvm import target as target_
from tvm import te
from tvm.contrib import cudnn
from .. import generic
from .injective import schedule_injective_from_existing
-
def schedule_softmax(outs):
"""Schedule for softmax op.
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
softmax = outs[0]
+ tgt = target_.Target.current(allow_none=False)
op_tag = softmax.op.tag
if op_tag == 'softmax_output':
raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
Got {0}'.format(op_tag))
+ # The nvptx backend only supports 32-bit warp shuffle instructions.
+ #
+ # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
+ def sched_warp_softmax():
+ if tgt.target_name == "nvptx":
+ return softmax.dtype == "float32" or softmax.dtype == "int32"
+ return True
+
if len(softmax.shape) > 2:
ops = [max_elem.op, expsum.op, softmax.op]
if exp is not None:
for op in ops:
s = schedule_injective_from_existing(s, op.output(0))
+
+ elif sched_warp_softmax():
+ # A warp of 32 threads performs a row reduction.
+ num_thread = tgt.thread_warp_size
+ block_x = te.thread_axis("blockIdx.x")
+ thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
+
+ # (4) softmax
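+ # Split the row so the innermost chunk of 4 can be vectorized (128-bit accesses for fp32).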
+ xo, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
+ _, xii = s[softmax].split(xi, factor=4)
+ s[softmax].vectorize(xii)
+ s[softmax].bind(xo, thread_x)
+ s[softmax].bind(softmax.op.axis[0], block_x)
+
+ # (3) expsum
+ k = expsum.op.reduce_axis[0]
+ ko, _ = s[expsum].split(k, nparts=num_thread)
+ s[expsum].bind(ko, thread_x)
+ s[expsum].compute_at(s[softmax], xo)
+
+ # (2) exp
+ if exp is not None:
+ xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
+ _, xii = s[exp].split(xi, factor=4)
+ s[exp].vectorize(xii)
+ s[exp].bind(xo, thread_x)
+ s[exp].compute_at(s[expsum], expsum.op.axis[0])
+ s[exp].compute_at(s[softmax], softmax.op.axis[0])
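+ # Buffers in "warp" scope are rewritten into warp shuffles by the lower_warp_memory pass.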
+ s[exp].set_scope("warp")
+
+ # (1) max_elem
+ k = max_elem.op.reduce_axis[0]
+ ko, _ = s[max_elem].split(k, nparts=num_thread)
+ s[max_elem].bind(ko, thread_x)
+ if exp is not None:
+ s[max_elem].compute_at(s[exp], xo)
+ else:
+ s[max_elem].bind(max_elem.op.axis[0], block_x)
+
else:
num_thread = 64
block_x = te.thread_axis("blockIdx.x")