Add shuffle support to TVM (#3633)
authorJian Weng <jian.weng@ucla.edu>
Thu, 1 Aug 2019 19:52:33 +0000 (12:52 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 1 Aug 2019 19:52:33 +0000 (12:52 -0700)
12 files changed:
include/tvm/ir_functor_ext.h
include/tvm/ir_visitor.h
src/codegen/build_common.h
src/codegen/codegen_c.cc
src/codegen/codegen_c.h
src/codegen/codegen_cuda.cc
src/codegen/codegen_cuda.h
src/codegen/llvm/codegen_llvm.cc
src/codegen/llvm/codegen_llvm.h
src/pass/ir_visitor.cc
tests/python/unittest/test_codegen_cuda.py
tests/python/unittest/test_codegen_llvm.py

index d5b27a6..a7d91ea 100644 (file)
@@ -199,6 +199,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
     IR_EXPR_FUNCTOR_DISPATCH(Not);
     IR_EXPR_FUNCTOR_DISPATCH(Select);
     IR_EXPR_FUNCTOR_DISPATCH(Ramp);
+    IR_EXPR_FUNCTOR_DISPATCH(Shuffle);
     IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
     IR_EXPR_FUNCTOR_DISPATCH(IntImm);
     IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
index c36f631..f20b913 100644 (file)
@@ -131,6 +131,7 @@ class TVM_DLL IRVisitor {
   virtual void Visit_(const Not* op);
   virtual void Visit_(const Select* op);
   virtual void Visit_(const Ramp* op);
+  virtual void Visit_(const Shuffle* op);
   virtual void Visit_(const Broadcast* op);
   virtual void Visit_(const AssertStmt* op);
   virtual void Visit_(const ProducerConsumer* op);
index 713922b..0bb4002 100644 (file)
@@ -26,6 +26,7 @@
 #define TVM_CODEGEN_BUILD_COMMON_H_
 
 #include <tvm/codegen.h>
+#include <tvm/ir.h>
 #include <unordered_map>
 #include <string>
 #include "../runtime/meta_data.h"
index 395d3f3..043c647 100644 (file)
@@ -728,6 +728,10 @@ void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) {  // NOLINT(*)
   os << "))";
 }
 
+void CodeGenC::VisitExpr_(const Shuffle* op, std::ostream& os) {
+  LOG(FATAL) << "Shuffle: not supported ";
+}
+
 void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLINT(*)
   LOG(FATAL) << "Broadcast: not supported ";
 }
index 5e84cd9..5cd30f1 100644 (file)
@@ -126,6 +126,7 @@ class CodeGenC :
   void VisitExpr_(const Not* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Select* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Ramp* op, std::ostream& os) override;  // NOLINT(*)
+  void VisitExpr_(const Shuffle* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const Broadcast* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const IntImm* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const UIntImm* op, std::ostream& os) override;  // NOLINT(*)
index a324731..d13b2c9 100644 (file)
@@ -205,7 +205,7 @@ void CodeGenCUDA::PrintVecBinaryOp(
 
 void CodeGenCUDA::PrintVecElemLoad(
     const std::string& vec, Type t, int i, std::ostream& os) {  // NOLINT(*)
-  const char access[] = {'x', 'y', 'z', 'w'};
+  static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < 4);
   os << vec << "." << access[i];
 }
@@ -213,7 +213,7 @@ void CodeGenCUDA::PrintVecElemLoad(
 void CodeGenCUDA::PrintVecElemStore(
     const std::string& vec, Type t, int i, const std::string& value) {
   this->PrintIndent();
-  const char access[] = {'x', 'y', 'z', 'w'};
+  static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < 4);
   stream << vec << "." << access[i] << " = " << value << ";\n";
 }
@@ -308,7 +308,7 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLIN
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->type, os);
-  os << "(";
+  os << '(';
   for (int i = 0; i < op->lanes; ++i) {
     if (i != 0) os << ", ";
     os << v;
@@ -316,6 +316,23 @@ void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) {   // NOLIN
   os << ')';
 }
 
+void CodeGenCUDA::VisitExpr_(const Shuffle* op, std::ostream &os) {
+  std::vector<std::string> to_shuffle(op->vectors.size());
+  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
+    CHECK(op->vectors[i].type().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
+    to_shuffle[i] = PrintExpr(op->vectors[i]);
+  }
+  os << "make_";
+  PrintType(op->type, os);
+  os << '(';
+  for (int i = 0, e = op->indices.size(); i < e; ++i) {
+    const int64_t *val = as_const_int(op->indices[i]);
+    CHECK(val && *val >= 0 && (int) *val < (int) to_shuffle.size());
+    if (i != 0) os << ", ";
+    os << to_shuffle[*val];
+  }
+  os << ')';
+}
 
 inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
   switch (op->type.bits()) {
index acd759f..61c6fa3 100644 (file)
@@ -57,6 +57,7 @@ class CodeGenCUDA final : public CodeGenC {
   void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
   // overload visitor
   void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*)
+  void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
   void VisitExpr_(const FloatImm *op, std::ostream& os) final;
   void VisitStmt_(const Evaluate *op) final;
index 537de1e..5bc415f 100644 (file)
@@ -30,6 +30,7 @@
 
 #include "codegen_llvm.h"
 #include "codegen_cpu.h"
+#include "../build_common.h"
 #include "../../pass/ir_util.h"
 #include "../../arithmetic/compute_expr.h"
 
@@ -446,6 +447,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
 llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
   int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
   if (extent == num_elems && begin == 0) return vec;
+  CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
   std::vector<llvm::Constant*> indices;
   indices.reserve(extent);
   for (int i = 0; i < extent; ++i) {
@@ -481,6 +483,7 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
 llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
   // concat vector, tree shape reduction
   int total_lanes = 0;
+
   for (llvm::Value* v : vecs) {
     total_lanes += static_cast<int>(
         v->getType()->getVectorNumElements());
@@ -652,12 +655,14 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
     CHECK_GE(op->args.size(), 2U);
     llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
         op->args[0].as<UIntImm>()->value);
-    uint64_t num_signature = op->args[1].as<UIntImm>()->value;
+    const uint64_t *num_signature = as_const_uint(op->args[1]);
+    CHECK(num_signature) << "The second argument should be a uint represents number of arguments, "
+                         << "but " << op->args[1] << " got!\n";
     std::vector<llvm::Value*> arg_value;
     std::vector<llvm::Type*> sig_type;
     for (size_t i = 2; i < op->args.size(); ++i) {
       arg_value.push_back(MakeValue(op->args[i]));
-      if (i - 2 < num_signature) {
+      if (i - 2 < *num_signature) {
         sig_type.push_back(arg_value.back()->getType());
       }
     }
@@ -1002,6 +1007,26 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
   return vec;
 }
 
+llvm::Value* CodeGenLLVM::VisitExpr_(const Shuffle* op) {
+  std::vector<llvm::Value *> vecs(op->vectors.size());
+  int total_lanes = 0;
+  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
+    vecs[i] = VisitExpr(op->vectors[i]);
+    total_lanes += op->vectors[i].type().lanes();
+  }
+  llvm::Value* v0 = CreateVecConcat(vecs);
+  std::vector<uint32_t> idx(op->indices.size());
+  for (int i = 0, e = op->indices.size(); i < e; ++i) {
+    const int64_t *val = as_const_int(op->indices[i]);
+    CHECK(val && *val >= 0 && *val  < total_lanes) << "Shuffled indeces are suppose to be int, "
+      << "but get " << op->indices[i] << "\n";
+    idx[i] = *val;
+  }
+  llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx);
+  auto res = builder_->CreateShuffleVector(v0, llvm::UndefValue::get(v0->getType()), mask);
+  return res;
+}
+
 llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) {
   return CreateBroadcast(MakeValue(op->value), op->lanes);
 }
index 6be860a..13dc6ac 100644 (file)
@@ -131,6 +131,7 @@ class CodeGenLLVM :
   llvm::Value* VisitExpr_(const Load* op) override;
   llvm::Value* VisitExpr_(const Call* op) override;
   llvm::Value* VisitExpr_(const Ramp* op) override;
+  llvm::Value* VisitExpr_(const Shuffle* op) override;
   llvm::Value* VisitExpr_(const Broadcast* op) override;
   // stmt
   void VisitStmt_(const Store* op) override;
index dd469f4..fde183e 100644 (file)
@@ -177,6 +177,13 @@ void IRVisitor::Visit_(const Ramp *op) {
   this->Visit(op->stride);
 }
 
+void IRVisitor::Visit_(const Shuffle *op) {
+  for (const auto &elem : op->indices)
+    this->Visit(elem);
+  for (const auto &elem : op->vectors)
+    this->Visit(elem);
+}
+
 void IRVisitor::Visit_(const Broadcast *op) {
   this->Visit(op->value);
 }
@@ -269,6 +276,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
 .DISPATCH_TO_VISIT(Not)
 .DISPATCH_TO_VISIT(Select)
 .DISPATCH_TO_VISIT(Ramp)
+.DISPATCH_TO_VISIT(Shuffle)
 .DISPATCH_TO_VISIT(Broadcast)
 .DISPATCH_TO_VISIT(AssertStmt)
 .DISPATCH_TO_VISIT(ProducerConsumer)
index 8fe6720..e8439de 100644 (file)
@@ -154,9 +154,53 @@ def test_cuda_inf_nan():
     check_inf_nan(ctx, 1, float('nan'), 'float64')
 
 
+def test_cuda_shuffle():
+    if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    a = tvm.placeholder((64, ), 'int32')
+    b = tvm.placeholder((64, ), 'int32')
+    c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)])
+    sch = tvm.create_schedule(c.op)
+    x = c.op.axis[0]
+    xo, xi = sch[c].split(x, 4)
+    thrx = tvm.thread_axis("threadIdx.x")
+    sch[c].bind(xo, thrx)
+    sch[c].vectorize(xi)
+
+    def my_vectorize(stmt):
+        def vectorizer(op):
+            if op.for_type == tvm.stmt.For.Vectorized:
+                four = tvm.const(4, 'int32')
+                idx = tvm.make.Ramp(thrx.var * four, tvm.const(1, 'int32'), 4)
+                all_ones = tvm.const(1, 'int32x4')
+                store = op.body
+                value = store.value
+                new_a = tvm.make.Load('int32x4', value.a.buffer_var, idx, all_ones)
+                bs, ids = [], []
+                for i in range(4):
+                    bs.append(tvm.make.Load('int32', value.b.buffer_var, thrx.var * four + tvm.const(i, 'int32')))
+                    ids.append(tvm.const(3 - i, 'int32'))
+                new_b = tvm.make.Shuffle(bs, ids)
+                return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
+            return None
+        return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
+
+    with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
+        module = tvm.build(sch, [a, b, c], target='cuda')
+        a_ = np.array(list(range(64)), dtype='int32')
+        b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
+        c_ = np.zeros((64, ), dtype='int32')
+        ref = a_ +  np.array((list(range(4))) * 16, dtype='int32')
+        nda, ndb, ndc = [tvm.ndarray.array(i, tvm.gpu(0)) for i in [a_, b_, c_]]
+        module(nda, ndb, ndc)
+        tvm.testing.assert_allclose(ndc.asnumpy(), ref)
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
     test_cuda_vectorize_load()
     test_cuda_make_int8x4()
     test_cuda_inf_nan()
+    test_cuda_shuffle()
index ed6fedc..34dad36 100644 (file)
@@ -548,6 +548,37 @@ def test_dwarf_debug_information():
     check_llvm_object()
     check_llvm_ir()
 
+
+def test_llvm_shuffle():
+    a = tvm.placeholder((8, ), 'int32')
+    b = tvm.placeholder((8, ), 'int32')
+    c = tvm.compute((8, ), lambda x: a[x] + b[7-x])
+    sch = tvm.create_schedule(c.op)
+
+    def my_vectorize(stmt):
+
+        def vectorizer(op):
+            store = op.body
+            idx = tvm.make.Ramp(tvm.const(0, 'int32'), tvm.const(1, 'int32'), 8)
+            all_ones = tvm.const(1, 'int32x8')
+            value = store.value
+            b_idx = tvm.make.Shuffle([idx], [tvm.const(i, 'int32') for i in range(7, -1, -1)])
+            new_a = tvm.make.Load('int32x8', value.a.buffer_var, idx, all_ones)
+            new_b = tvm.make.Load('int32x8', value.b.buffer_var, b_idx, all_ones)
+            value = new_a + new_b
+            return tvm.make.Store(store.buffer_var, new_a + new_b, idx, all_ones)
+
+        return tvm.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
+
+    with tvm.build_config(add_lower_pass=[(1, my_vectorize)]):
+        ir = tvm.lower(sch, [a, b, c], simple_mode=True)
+        module = tvm.build(sch, [a, b, c])
+        a_ = tvm.ndarray.array(np.arange(1, 9, dtype='int32'))
+        b_ = tvm.ndarray.array(np.arange(8, 0, -1, dtype='int32'))
+        c_ = tvm.ndarray.array(np.zeros((8, ), dtype='int32'))
+        module(a_, b_, c_)
+        tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
+
 if __name__ == "__main__":
     test_llvm_import()
     test_alignment()
@@ -567,3 +598,4 @@ if __name__ == "__main__":
     test_llvm_div()
     test_llvm_fp_math()
     test_dwarf_debug_information()
+    test_llvm_shuffle()