[CodeGen][CUDA] Improve CUDA vectorizer (#4736)
authorwpan11nv <60017475+wpan11nv@users.noreply.github.com>
Sat, 18 Jan 2020 02:58:11 +0000 (18:58 -0800)
committerWuwei Lin <wuwei@apache.org>
Sat, 18 Jan 2020 02:58:11 +0000 (21:58 -0500)
- Fixes issues to enable fp16 vectorizer. Now correct packing and
  unpacking CUDA code will be emitted. Enabled more unit tests.

- Do not emit code to read the first lane from an undef variable

  int _3;
  _3 = _3 & ~(0x000000ff << 0) | ...

  and emit the following code instead:

  _3 = (((0x000000ff & (_1 >> 0))+(0x000000ff & (_2 >> 0))) << 0);

  Note that nvcc 10.2 is forgiving and emits the same code for both cases.
  A warning appears in test_codegen_cuda.py.

Signed-off-by: Wei Pan <weip@nvidia.com>
include/tvm/runtime/data_type.h
src/codegen/codegen_cuda.cc
src/codegen/literal/cuda_half_t.h
tests/python/unittest/test_codegen_cuda.py

index cb58e97..7e0ef49 100644 (file)
@@ -92,6 +92,10 @@ class DataType {
   bool is_float() const {
     return code() == DataType::kFloat;
   }
+  /*! \return whether type is a float16 type. */
+  bool is_float16() const {
+    return is_float() && bits() == 16;
+  }
   /*! \return whether type is an int type. */
   bool is_int() const {
     return code() == DataType::kInt;
index 6f394a1..b6ba17f 100644 (file)
@@ -73,6 +73,7 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#else\n";
     decl_stream << _cuda_half_t_def;
     decl_stream << "#endif\n\n";
+    decl_stream << _cuda_half_util;
   }
 
   if (enable_int8_) {
@@ -122,8 +123,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
         if (lanes == 1) {
           os << "half";
         } else if (lanes <= 8) {
+          // Emit CUDA code to access fp16 vector elements.
+          //
+          // half4 is stored as uint2
+          //
+          // h4.x is emitted as *(half2*)(&(u2.x)).x
+          // h4.y is emitted as *(half2*)(&(u2.x)).y
+          // h4.z is emitted as *(half2*)(&(u2.y)).x
+          // h4.w is emitted as *(half2*)(&(u2.y)).y
+          //
           CHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
-          os << "float" << lanes / 2;
+          os << "uint" << lanes / 2;
         } else {
           fail = true;
         }
@@ -243,9 +253,12 @@ void CodeGenCUDA::PrintVecBinaryOp(
 void CodeGenCUDA::PrintVecElemLoad(
     const std::string& vec, DataType t, int i, std::ostream& os) {  // NOLINT(*)
   static const char access[] = {'x', 'y', 'z', 'w'};
-  CHECK(i >= 0 && i < 4);
+  CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if (t.is_int() && t.bits() == 8) {
     os << "(0x000000ff & (" << vec << " >> " << i * 8 << "))";
+  } else if (t.is_float16()) {
+    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
+       << access[i % 2];
   } else {
     os << vec << "." << access[i];
   }
@@ -255,10 +268,17 @@ void CodeGenCUDA::PrintVecElemStore(
     const std::string& vec, DataType t, int i, const std::string& value) {
   this->PrintIndent();
   static const char access[] = {'x', 'y', 'z', 'w'};
-  CHECK(i >= 0 && i < 4);
+  CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if (t.is_int() && t.bits() == 8) {
-    stream << vec << "=" << vec << " & ~(0x000000ff << " << i * 8 << ") | ("
-        << value << " << " << i * 8 << ");\n";
+    stream << vec << "=";
+    // Do not read the first undef lane.
+    if (i != 0) {
+      stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
+    }
+    stream << "(" << value << " << " << i * 8 << ");\n";
+  } else if (t.is_float16()) {
+    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
+           << access[i % 2] << " = " << value << ";\n";
   } else {
     stream << vec << "." << access[i] << " = " << value << ";\n";
   }
@@ -462,6 +482,19 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // N
     return;
   }
 
+  if (op->dtype.is_float16()) {
+    std::string v = PrintExpr(op->value);
+    os << "make_";
+    PrintType(op->dtype, os);
+    os << '(';
+    for (int i = 0; i < op->lanes / 2; ++i) {
+      if (i != 0) os << ", ";
+      os << "__pack_half2(" << v << ", " << v <<  ")";
+    }
+    os << ')';
+    return;
+  }
+
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->dtype, os);
index 630a741..7e9c72e 100644 (file)
@@ -285,4 +285,14 @@ TVM_XINLINE half __float2half_rn(const float a) {
 }
 )";
 
+static constexpr const char* _cuda_half_util = R"(
+// Pack two half values.
+static inline __device__ __host__ unsigned
+__pack_half2(const half x, const half y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v0 << 16) | v1;
+}
+)";
+
 #endif  // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
index 27a8d87..5d05b6d 100644 (file)
 import tvm
 import numpy as np
 import unittest
-from tvm.contrib.nvcc import have_fp16, have_int8
+from tvm.contrib.nvcc import parse_compute_version, have_int8
 from tvm.contrib import nvcc
 
 tx = tvm.thread_axis("threadIdx.x")
 bx = tvm.thread_axis("blockIdx.x")
 
-
 def test_cuda_vectorize_add():
     num_thread = 8
     def check_cuda(dtype, n, lanes):
         if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
             print("skip because cuda is not enabled..")
             return
-        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
-            print("skip because gpu does not support fp16")
+        if dtype == "float16":
+            major, minor = parse_compute_version(tvm.gpu(0).compute_version)
+            # fp16 starts from 5.3
+            if major < 6 or (major == 5 and minor < 3):
+                print("skip because gpu does not support fp16")
             return
         if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
             print("skip because gpu does not support int8")
@@ -52,13 +54,13 @@ def test_cuda_vectorize_add():
         tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
 
     check_cuda("float32", 64, 2)
-    check_cuda("int8", 64, 4)
-    # check_cuda("float16", 64, 2)
-
-    # TODO(tvm-team) fix fp16 codegen here
-    # or hit an error if it is less frequently used.
-    # check_cuda("float16", 64, 2)
-
+    check_cuda("float32", 64, 3)
+    check_cuda("float32", 64, 4)
+    check_cuda("int8",    64, 4)
+    check_cuda("float16", 64, 2)
+    check_cuda("float16", 64, 4)
+    check_cuda("float16", 64, 6)
+    check_cuda("float16", 64, 8)
 
 def test_cuda_multiply_add():
     num_thread = 8