[CODEGEN] Fix code generation bugs for C/CUDA & Improve VerifyGPUCode pass (#6041)
authorLianmin Zheng <lianminzheng@gmail.com>
Mon, 13 Jul 2020 17:46:27 +0000 (10:46 -0700)
committerGitHub <noreply@github.com>
Mon, 13 Jul 2020 17:46:27 +0000 (10:46 -0700)
src/target/source/codegen_c.cc
src/tir/analysis/verify_gpu_code.cc
tests/python/unittest/test_target_codegen_c_host.py
tests/python/unittest/test_target_codegen_cuda.py

index 7c3c830..1530892 100644 (file)
@@ -629,12 +629,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
       this->PrintExpr(op->args[0], os);
       os << " == NULL)";
     } else if (op->op.same_as(builtin::reinterpret())) {
-      // generate (*( TYPE *)(&(ARG)))
+      int ssa_scope = BeginScope();
+      std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
       os << "(*(";
       this->PrintType(op->dtype, os);
-      os << " *)(&(";
-      this->PrintExpr(op->args[0], os);
-      os << ")))";
+      os << " *)(&(" << rhs << ")))";
+      EndScope(ssa_scope);
     } else if (op->op.same_as(builtin::isnan())) {
       os << "(";
       this->PrintExpr(op->args[0], os);
@@ -720,14 +720,15 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
   } else {
     CHECK(is_one(op->predicate)) << "Predicated store is not supported";
     arith::PVar<PrimExpr> base;
+
+    // The assignment below introduces side-effect, and the resulting value cannot
+    // be reused across multiple expression, thus a new scope is needed
+    int vec_scope = BeginScope();
+
     if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
       std::string value = this->PrintExpr(op->value);
       this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
     } else {
-      // The assignment below introduces side-effect, and the resulting value cannot
-      // be reused across multiple expression, thus a new scope is needed
-      int vec_scope = BeginScope();
-
       // store elements seperately
       std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype());
       std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
@@ -754,8 +755,8 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
         PrintVecElemLoad(value, op->value.dtype(), i, stream);
         stream << ";\n";
       }
-      EndScope(vec_scope);
     }
+    EndScope(vec_scope);
   }
 }
 
index 9477e04..d221dde 100644 (file)
@@ -145,6 +145,18 @@ class GPUCodeVerifier : public StmtExprVisitor {
     ExprVisitor::VisitExpr_(op);
   }
 
+  void VisitStmt_(const StoreNode* op) {
+    // Currently not able to check out: If the index expression failed
+    // to be simplified to a RampNode
+    if (op->index->IsInstance<RampNode>()) {
+      if (op->index->dtype.lanes() > 1) {
+        valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
+                  max_vector_bytes_;
+      }
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
  private:
   int nest_level_{0};
 
index 698dd74..31353ef 100644 (file)
@@ -98,7 +98,7 @@ def test_reinterpret():
     nn = 1024
     n = tvm.runtime.convert(nn)
     A = te.placeholder((n,), name='A', dtype="int32")
-    B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", A(*i)), name='B')
+    B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", 2 + A(*i)), name='B')
     s = te.create_schedule(B.op)
 
     def check_c():
@@ -114,7 +114,7 @@ def test_reinterpret():
         b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
         fadd(a, b)
         tvm.testing.assert_allclose(
-            b.asnumpy(), a.asnumpy().view('float32'))
+            b.asnumpy(), (2 + a.asnumpy()).view('float32'))
     check_c()
 
 
index c977334..4cd08d0 100644 (file)
@@ -874,6 +874,46 @@ def test_vectorized_cooperative_fetching_xy():
 
     vcf_check_common(s, [A, B, C])
 
+def test_unrolled_vectorization():
+    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    dtype = 'float32'
+    target = 'cuda'
+    
+    ## Compute declaration
+    N = 128
+    A = te.placeholder((N, N), name='A')
+    B = te.placeholder((N, N), name='B')
+    k = te.reduce_axis((0, N), name='k')
+    C = te.compute((N, N), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+    
+    ## Schedule
+    s = te.create_schedule([C.op])
+    CC = s.cache_write(C, "local")
+    i, j = s[C].op.axis
+    bx, tx, ii, ji = s[C].tile(i, j, 1, 2)
+    s[C].bind(bx, te.thread_axis("blockIdx.x"))
+    s[C].bind(tx, te.thread_axis("threadIdx.x"))
+    s[C].vectorize(ji)
+    s[CC].compute_at(s[C], tx)
+    i, j = s[CC].op.axis
+    k = s[CC].op.reduce_axis[0]
+    ko, ki = s[CC].split(k, 2)
+    s[CC].unroll(ki)
+    s[CC].vectorize(j)
+    
+    ## Check correctness
+    ctx = tvm.context(target)
+    a_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), ctx=ctx)
+    b_tvm = tvm.nd.array(np.ones((N, N)).astype(dtype), ctx=ctx)
+    c_tvm = tvm.nd.empty((N, N), ctx=ctx)
+    func_tvm = tvm.build(s, [A, B, C], target=target)
+    func_tvm(a_tvm, b_tvm, c_tvm)
+    c_np = c_tvm.asnumpy()
+    tvm.testing.assert_allclose(c_np, N * np.ones((N, N)))
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
@@ -897,3 +937,5 @@ if __name__ == "__main__":
     test_cuda_vectorize_load_permute_pad()
     test_vectorized_cooperative_fetching_x()
     test_vectorized_cooperative_fetching_xy()
+    test_unrolled_vectorization()
+