[CodeGen][CUDA] Enhance CUDA codegen for SelectNode (#4983)
authorWei Pan <60017475+wpan11nv@users.noreply.github.com>
Wed, 11 Mar 2020 02:16:59 +0000 (19:16 -0700)
committerGitHub <noreply@github.com>
Wed, 11 Mar 2020 02:16:59 +0000 (11:16 +0900)
- This patch allows CUDA backend to emit correct code for
  selects with vector conditions, which may be produced
  by floordiv op lowering etc..

- This already works for llvm BE, as llvm select instruction
  supports vector conditions.

Signed-off-by: Wei Pan <weip@nvidia.com>
include/tvm/runtime/data_type.h
src/target/source/codegen_cuda.cc
src/target/source/codegen_cuda.h
tests/python/unittest/test_codegen_cuda.py

index e6f5e55..6379a13 100644 (file)
@@ -112,6 +112,10 @@ class DataType {
   bool is_vector() const {
     return lanes() > 1;
   }
+  /*! \return whether type is a bool vector type. */
+  bool is_vector_bool() const {
+    return is_vector() && bits() == 1;
+  }
   /*!
    * \brief Create a new data type by change lanes to a specified value.
    * \param lanes The target number of lanes.
index d5cab6e..24f655b 100644 (file)
@@ -135,6 +135,13 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
     }
   } else if (t == DataType::Bool()) {
     os << "bool"; return;
+  } else if (t.is_vector_bool()) {
+    // CUDA does not support bool vectors.
+    // Use ushort vectors to represent instead.
+    int n = t.lanes();
+    if (n <= 4) {
+      os << "ushort" << n; return;
+    }
   } else if (t.is_uint() || t.is_int()) {
     if (t.is_uint()) {
       if (t.lanes() != 1) {
@@ -226,7 +233,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
 }
 
 void CodeGenCUDA::PrintVecBinaryOp(
-    const std::string&op, DataType t,
+    const std::string& op, DataType t,
     PrimExpr lhs, PrimExpr rhs, std::ostream& os) {  // NOLINT(*)
   // unpacking operations.
   int lanes = t.lanes();
@@ -561,6 +568,48 @@ void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) {
   os << ')';
 }
 
+void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) {
+  // Non-vector cases.
+  if (!op->dtype.is_vector()) {
+    CodeGenC::VisitExpr_(op, os);
+    return;
+  }
+
+  // Codegen vector condition case by serializing the select op.
+  CHECK(op->false_value->dtype == op->dtype &&
+        op->true_value->dtype == op->dtype &&
+        op->dtype.lanes() == op->condition.dtype().lanes());
+
+  int lanes = op->dtype.lanes();
+  int scope = BeginScope();
+
+  std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
+  std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
+  std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);
+  std::string r_var = GetUniqueName("_");
+
+  this->PrintIndent();
+  this->PrintType(op->dtype, stream);
+  stream << ' ' << r_var << ";\n";
+
+  // The condition is stored as an ushort vector.
+  DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes);
+
+  for (int i = 0; i < lanes; ++i) {
+    std::ostringstream item;
+    item << "(bool(";
+    PrintVecElemLoad(c_var, memory_ty, i, item);
+    item << ")?";
+    PrintVecElemLoad(t_var, op->dtype, i, item);
+    item << ':';
+    PrintVecElemLoad(f_var, op->dtype, i, item);
+    item << ')';
+    PrintVecElemStore(r_var, op->dtype, i, item.str());
+  }
+  os << r_var;
+  EndScope(scope);
+}
+
 inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
   switch (op->dtype.bits()) {
     case 64: case 32: {
index d0a98a6..a634c10 100644 (file)
@@ -43,11 +43,11 @@ class CodeGenCUDA final : public CodeGenC {
     return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
-  void VisitStmt_(const tir::ForNode* op) final;
+  void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
   void PrintVecBinaryOp(
-      const std::string&op, DataType t,
+      const std::string& op, DataType t,
       PrimExpr lhs, PrimExpr rhs, std::ostream& os) final;  // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
   void PrintVecElemLoad(
@@ -58,6 +58,7 @@ class CodeGenCUDA final : public CodeGenC {
   // overload visitor
   void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
   void VisitExpr_(const CallNode *op, std::ostream& os) final;
index f94d8c3..083cede 100644 (file)
@@ -321,6 +321,33 @@ def test_cuda_reduction():
     check_cuda("float32")
     check_cuda("float16")
 
+def test_cuda_floordiv_with_vectorization():
+    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    with tvm.target.cuda():
+        # B[i] = A[floordiv(i, k)]
+        n = 256
+        k = 37
+        A = te.placeholder((n,), name='A')
+        B = te.compute((n,), lambda i: A[tvm.tir.floordiv(i, k)], name='B')
+        s = te.create_schedule(B.op)
+        xo, xi = s[B].split(B.op.axis[0], nparts=1)
+        xio, xii = s[B].split(xi, factor=4)
+        s[B].vectorize(xii)
+        s[B].bind(xo, bx)
+        s[B].bind(xio, tx)
+        func = tvm.build(s, [A, B], 'cuda')
+
+        ctx = tvm.gpu(0)
+        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
+        b_np = np.array([a_np[i//k] for i in range(0, n)])
+        a_nd = tvm.nd.array(a_np, ctx)
+        b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
+        func(a_nd, b_nd)
+        tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
@@ -331,4 +358,5 @@ if __name__ == "__main__":
     test_cuda_reducition_binding()
     test_rfactor_predicates()
     test_cuda_const_float_to_half()
-    test_cuda_reduction()
\ No newline at end of file
+    test_cuda_reduction()
+    test_cuda_floordiv_with_vectorization()
\ No newline at end of file