[CODEGEN][CUDA] Fix vector load (#5226)
authorLiangLiu <1432249204@qq.com>
Tue, 14 Apr 2020 03:35:31 +0000 (11:35 +0800)
committerGitHub <noreply@github.com>
Tue, 14 Apr 2020 03:35:31 +0000 (23:35 -0400)
* Fix high-low bit bug in __pack_half2

* Fix vector load

* Add unit8 support for PrintVecElemLoadExpr and BroadcastNode

src/target/source/codegen_c.cc
src/target/source/codegen_c.h
src/target/source/codegen_cuda.cc
src/target/source/codegen_cuda.h
src/target/source/literal/cuda_half_t.h
tests/python/unittest/test_target_codegen_cuda.py

index 444dc99..ac13f8a 100644 (file)
@@ -668,15 +668,7 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
       std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
       HandleVolatileLoads(ref, op, os);
     } else {
-      // The assignment below introduces side-effect, and the resulting value cannot
-      // be reused across multiple expression, thus a new scope is needed
-      int vec_scope = BeginScope();
-
-      // load seperately.
-      std::string svalue = GetUniqueName("_");
-      this->PrintIndent();
-      this->PrintType(op->dtype, stream);
-      stream << ' ' << svalue << ";\n";
+      std::ostringstream svalue_expr;
       std::string sindex = SSAGetID(PrintExpr(op->index), op->index.dtype());
       std::string vid = GetVarID(op->buffer_var.get());
       DataType elem_type = op->dtype.element_of();
@@ -699,10 +691,9 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
         value_temp << '[';
         PrintVecElemLoad(sindex, op->index.dtype(), i, value_temp);
         value_temp << ']';
-        PrintVecElemStore(svalue, op->dtype, i, value_temp.str());
+        PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
       }
-      os << svalue;
-      EndScope(vec_scope);
+      os << svalue_expr.str();
     }
   }
 }
@@ -955,5 +946,30 @@ void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) {
   PrintStmt(op->body);
 }
 
+void CodeGenC::PrintVecElemLoadExpr(
+    DataType t, int i, const std::string& value, std::ostream& os) {
+  CHECK_GT(t.lanes(), 1);
+  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
+    if (i != 0) {
+      os << "|";
+    }
+    os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
+    return;
+  }
+
+  if (i == 0) {
+    os << "((";
+    PrintType(t, os);
+    os << t.lanes() << ")(";
+  }
+  os << value;
+  if (i != t.lanes() - 1) {
+    os << ",";
+  } else {
+    os << "))";
+  }
+  return;
+}
+
 }  // namespace codegen
 }  // namespace tvm
index 30ad890..49139de 100644 (file)
@@ -191,6 +191,8 @@ class CodeGenC :
       const std::string& vec, DataType t, int i, const std::string& value);
   // Get a cast type from to
   virtual std::string CastFromTo(std::string value, DataType from, DataType target);
+  // Get load of single element with expression
+  virtual void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os);
 
  protected:
   // Print reference to struct location
index 9c4fc69..c7971ce 100644 (file)
@@ -591,13 +591,17 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
 }
 
 void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {   // NOLINT(*)
-  if (op->dtype.is_int() && op->dtype.bits() == 8 && op->lanes == 4) {
+  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) {
     // make_int8x4
     const int64_t *p = as_const_int(op->value);
     CHECK(p);
     int64_t v = *p & 0xFF;
     v = (v << 24) | (v << 16) | (v << 8) | v;
-    os << "(int)" << v;
+    if (op->dtype.is_uint()) {
+      os << "(uint)" << v;
+    } else {
+      os << "(int)" << v;
+    }
     return;
   }
 
@@ -796,5 +800,49 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value,
   }
 }
 
+void CodeGenCUDA::PrintVecElemLoadExpr(
+    DataType t, int i, const std::string& value, std::ostream& os) {
+  CHECK_GT(t.lanes(), 1);
+  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
+    if (i != 0) {
+      os << "|";
+    }
+    os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
+    return;
+  }
+
+  if (t.is_float16()) {
+    if (i == 0) {
+      os << "make_";
+      PrintType(t, os);
+      os << '(';
+    }
+    if (i % 2 == 0) {
+      os << "__pack_half2(" << value;
+    } else {
+      os << "," << value << ")";
+      if (i != t.lanes() - 1) {
+        os << ",";
+      } else {
+        os << ")";
+      }
+    }
+    return;
+  }
+
+  if (i == 0) {
+    os << "make_";
+    PrintType(t, os);
+    os << "(";
+  }
+  os << value;
+  if (i != t.lanes() - 1) {
+    os << ",";
+  } else {
+    os << ")";
+  }
+  return;
+}
+
 }  // namespace codegen
 }  // namespace tvm
index 6ba7487..d1db704 100644 (file)
@@ -55,6 +55,7 @@ class CodeGenCUDA final : public CodeGenC {
   void PrintVecElemStore(
       const std::string& vec, DataType t, int i, const std::string& value) final;
   void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
+  void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
   // overload visitor
   void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
index fd0652a..858ac85 100644 (file)
@@ -291,7 +291,7 @@ static inline __device__ __host__ unsigned
 __pack_half2(const half x, const half y) {
   unsigned v0 = *((unsigned short *)&x);
   unsigned v1 = *((unsigned short *)&y);
-  return (v0 << 16) | v1;
+  return (v1 << 16) | v0;
 }
 )";
 
index bb162f4..739fc6f 100644 (file)
@@ -543,6 +543,44 @@ def test_vectorized_popcount():
     run_test("uint32")
     run_test("uint64")
 
+def test_cuda_vectorize_load_permute_pad():
+    def check_cuda(dtype, n, l, padding, lanes):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        ctx = tvm.gpu(0)
+        A = tvm.te.placeholder((n, l), name='A', dtype=dtype)
+        B = tvm.te.compute((n // lanes, l + 2 * padding, lanes),
+                           lambda i, j, k: tvm.te.if_then_else(
+            tvm.te.any(j < padding, j >= l + padding),
+            tvm.runtime.convert(0).astype(dtype), A[i * lanes + k, j - padding]),
+            name='B')
+        s = te.create_schedule(B.op)
+        block, thread, vectorize = s[B].op.axis
+        s[B].bind(block, bx)
+        s[B].bind(thread, tx)
+        s[B].vectorize(vectorize)
+        fun = tvm.build(s, [A, B], "cuda", name="vector_load_permute_pad")
+        np_a = np.random.randint(
+            low=-128, high=127, size=(n, l)).astype(A.dtype)
+        a = tvm.nd.empty((n, l), A.dtype, ctx).copyfrom(np_a)
+        b = tvm.nd.empty((n // lanes, l + padding * 2, lanes), B.dtype, ctx)
+        fun(a, b)
+        np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1)
+        ref = np.pad(np_a_reshape, ((0, 0), (padding, padding),
+                                    (0, 0)), mode='constant', constant_values=0)
+        tvm.testing.assert_allclose(b.asnumpy(), ref)
+
+    check_cuda("int8", 64, 16, 3, 4)
+    check_cuda("uint8", 64, 16, 3, 4)
+    check_cuda("int32", 64, 16, 3, 4)
+    check_cuda("float16", 64, 16, 3, 4)
+    check_cuda("float32", 64, 16, 3, 4)
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
@@ -560,3 +598,4 @@ if __name__ == "__main__":
     test_vectorized_intrin1()
     test_vectorized_intrin2()
     test_vectorized_popcount()
+    test_cuda_vectorize_load_permute_pad()