[CodeGen][CUDA] Fix bugs (#5209)
authorWei Pan <60017475+wpan11nv@users.noreply.github.com>
Fri, 3 Apr 2020 06:57:40 +0000 (23:57 -0700)
committerGitHub <noreply@github.com>
Fri, 3 Apr 2020 06:57:40 +0000 (02:57 -0400)
- Support vectorized casts

- It is incorrect to extract elements from int8x4 with

   0x000000ff & (x >> i * 8)

  as this value is of type int in C/C++. If this expression
  is used for sign extensions, the sign bit will be wrong.
  Simply use C style casts instead and sign bits will just work.

Signed-off-by: Wei Pan <weip@nvidia.com>
src/target/source/codegen_cuda.cc
src/target/source/codegen_cuda.h
tests/python/unittest/test_target_codegen_cuda.py

index f8bc873..9c4fc69 100644 (file)
@@ -273,8 +273,10 @@ void CodeGenCUDA::PrintVecElemLoad(
     const std::string& vec, DataType t, int i, std::ostream& os) {  // NOLINT(*)
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
-  if (t.is_int() && t.bits() == 8) {
-    os << "(0x000000ff & (" << vec << " >> " << i * 8 << "))";
+  if ((t.is_int()) && t.bits() == 8) {
+    os << "((char)(" << vec << " >> " << i * 8 << "))";
+  } else if ((t.is_uint()) && t.bits() == 8) {
+    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
   } else if (t.is_float16()) {
     os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
        << access[i % 2];
@@ -288,7 +290,7 @@ void CodeGenCUDA::PrintVecElemStore(
   this->PrintIndent();
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
-  if (t.is_int() && t.bits() == 8) {
+  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
     stream << vec << "=";
     // Do not read the first undef lane.
     if (i != 0) {
@@ -352,6 +354,37 @@ void CodeGenCUDA::PrintStorageScope(
   }
 }
 
+void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
+  DataType from_ty = op->value.dtype();
+  DataType target_ty = op->dtype;
+  CHECK_EQ(target_ty.lanes(), from_ty.lanes());
+
+  // Emit simple C-style type conversion.
+  if (from_ty.is_scalar())
+    return CodeGenC::VisitExpr_(op, os);
+
+  // We could emit make_float4 like calls, but the emitted code looks
+  // too compact to read. Emit this as vectorized unary ops.
+  std::string sret = GetUniqueName("_");
+  this->PrintIndent();
+  this->PrintType(target_ty, stream);
+  stream << ' ' << sret << ";\n";
+  {
+    EnterScopeRAII scope(this);
+    std::string src = SSAGetID(PrintExpr(op->value), from_ty);
+    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
+      std::ostringstream val;
+      val << "(";
+      PrintType(target_ty.element_of(), val);
+      val << ")(";
+      PrintVecElemLoad(src, from_ty, i, val);
+      val << ")";
+      PrintVecElemStore(sret, target_ty, i, val.str());
+    }
+  }
+  os << sret;
+}
+
 void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
   if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
     need_mma_h_ = true;
index c31bdf5..6ba7487 100644 (file)
@@ -62,6 +62,7 @@ class CodeGenCUDA final : public CodeGenC {
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
   void VisitExpr_(const CallNode *op, std::ostream& os) final;
+  void VisitExpr_(const CastNode* op, std::ostream& os) final;
   void VisitStmt_(const EvaluateNode *op) final;
   void VisitStmt_(const AllocateNode *op) final;
   void VisitStmt_(const AttrStmtNode *op) final;
index e8c6cd1..75d6c14 100644 (file)
@@ -348,6 +348,55 @@ def test_cuda_floordiv_with_vectorization():
         func(a_nd, b_nd)
         tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
 
+def test_vectorized_casts():
+    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    def check(t0, t1):
+        if (t0 ==  "float16" or t1 == "float16") and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        # compute
+        n = 128
+        A = te.placeholder((n,), dtype=t0, name='A')
+        B = te.placeholder((n,), dtype=t1, name='B')
+        C = te.compute((n,), lambda i: A[i] + topi.cast(B[i], A.dtype), name='C')
+
+        # schedule
+        s = tvm.te.create_schedule(C.op)
+        ob, ib = s[C].split(s[C].op.axis[0], nparts=32)
+        _, iib = s[C].split(ib, factor=4)
+        s[C].vectorize(iib)
+        s[C].bind(ob, tx)
+        func = tvm.build(s, [A, B, C], "cuda")
+
+        # correctness
+        ctx = tvm.gpu(0)
+        low, high = (0, 20) if t0.startswith('u') or t1.startswith('u') else (-10, 10)
+        a_np = np.random.randint(low, high, size=n).astype(A.dtype)
+        b_np = np.random.randint(low, high, size=n).astype(B.dtype)
+        c_np = (a_np + b_np).astype(A.dtype)
+        a_nd = tvm.nd.array(a_np, ctx)
+        b_nd = tvm.nd.array(b_np, ctx)
+        c_nd = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np.dtype), ctx)
+        func(a_nd, b_nd, c_nd)
+        tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=1e-3)
+
+    def skip(t0, t1):
+        if t0 == t1:
+            return True
+        # CUDA does support cast between {u}int8 and fp16.
+        skip_set = {"float16", "uint8", "int8"}
+        if t0 in skip_set and t1 in skip_set:
+            return True
+        return False
+
+    types = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"]
+    for t0, t1 in [(x, y) for x in types for y in types if not skip(x, y)]:
+        check(t0, t1)
+
 def sched(B):
     s = te.create_schedule(B.op)
     io, ii = s[B].split(s[B].op.axis[0], nparts=1)
@@ -474,6 +523,7 @@ if __name__ == "__main__":
     test_cuda_make_int8x4()
     test_cuda_inf_nan()
     test_cuda_shuffle()
+    test_vectorized_casts()
     test_cuda_reducition_binding()
     test_rfactor_predicates()
     test_cuda_const_float_to_half()