[Optimization] Warp level reduction support for CUDA (#5498)
author    Wei Pan <60017475+wpan11nv@users.noreply.github.com>
Sat, 9 May 2020 15:38:38 +0000 (08:38 -0700)
committer GitHub <noreply@github.com>
Sat, 9 May 2020 15:38:38 +0000 (08:38 -0700)
- Added warp-level reduction support.

- Upgraded the shfl intrinsics to the sync versions.

- This is the building block for scheduling softmax-like operations.

Signed-off-by: Wei Pan <weip@nvidia.com>
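
For illustration (not part of this commit; it mirrors test_warp_reduction1 below), a minimal schedule that triggers the new lowering: binding a reduce axis of extent 32 to threadIdx.x on a CUDA target makes LowerThreadAllreduce emit warp shuffles instead of a shared-memory reduction, so the generated kernel should contain __shfl_down_sync.

import tvm
from tvm import te

m, n = 32, 256
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")
B = te.compute((m,), lambda i: te.max(A[i, k], axis=k), name="B")

s = te.create_schedule(B.op)
ko, _ = s[B].split(s[B].op.reduce_axis[0], nparts=32)   # 32 == warp size
xo, xi = s[B].split(s[B].op.axis[0], factor=4)
s[B].bind(ko, te.thread_axis((0, 32), "threadIdx.x"))   # reduction stays within one warp
s[B].bind(xi, te.thread_axis((0, 4), "threadIdx.y"))
s[B].bind(xo, te.thread_axis("blockIdx.x"))

fcuda = tvm.build(s, [A, B], "cuda")                     # requires a CUDA-enabled build
assert "__shfl_down_sync" in fcuda.imported_modules[0].get_source()
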
include/tvm/tir/expr.h
src/target/source/codegen_cuda.cc
src/target/source/codegen_cuda.h
src/target/source/intrin_rule_cuda.cc
src/target/source/intrin_rule_opencl.cc
src/target/source/literal/cuda_half_t.h
src/tir/transforms/lower_thread_allreduce.cc
src/tir/transforms/lower_warp_memory.cc
tests/python/integration/test_reduce.py

index bf0d4f9..afa9414 100644 (file)
@@ -1234,22 +1234,43 @@ constexpr const char *tvm_call_trace_packed_lowered =
  *  }
  */
 constexpr const char* tvm_storage_sync = "tvm_storage_sync";
+
 /*!
  * \brief See pseudo code
  *
- *  Type tvm_warp_shuffle(Type value, warp_id, width, warp_size) {
- *     return (value passed in by warp indicated by warp_id);
+ *  Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) {
+ *    return (value passed in by warp indicated by warp_id);
+ *  }
+ *
+ *  Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) {
+ *    return (value passed in by warp indicated by this_warp_id - offset);
+ *  }
+ *
+ *  Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) {
+ *    return (value passed in by warp indicated by this_warp_id + offset);
+ *  }
+ *
+ *  unsigned tvm_warp_activemask() {
+ *    return (32-bit mask of currently active threads in the calling warp);
  *  }
  *
  *  Parameter warp_id indicates the source thread ID in a warp.
  *
+ *  Parameter offset indicates the relative distance to this_warp_id.
+ *
  *  Parameter width indicates the number of threads involved in one
- *  shuffle. See CUDA document for __shfl.
+ *  shuffle. See the CUDA documentation for __shfl_sync, __shfl_up_sync,
+ *  __shfl_down_sync and __activemask.
  *
  *  Parameter warp_size is the size of a warp, which helps a backend
 *  to determine whether the width parameter is legal.
+ *
  */
 constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
+constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up";
+constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down";
+constexpr const char* tvm_warp_activemask = "tvm_warp_activemask";
+
 /*!
  * \brief Initialize the global barrier.
  *  Call this at beginning of kernel that need global barrier.
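
As a reading aid for the pseudo code above, a small NumPy model (an illustration only, assuming a full 32-lane warp and width == warp_size) of the lane-level semantics of tvm_warp_shuffle_down: lane i receives the value held by lane i + offset, and lanes whose source falls outside the warp keep their own value.

import numpy as np

def shuffle_down(values, offset, width=32):
    # values[i] models the register held by lane i of the warp.
    values = np.asarray(values)
    lanes = np.arange(width)
    src = lanes + offset
    out = values.copy()
    in_range = src < width
    out[in_range] = values[src[in_range]]
    return out

print(shuffle_down(np.arange(32), offset=4))  # lane 0 sees lane 4's value; lanes 28..31 keep their own
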
index a911e6b..591e4d0 100644 (file)
@@ -64,6 +64,10 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << _cuda_half_util;
   }
 
+  if (enable_warp_shuffle_) {
+    decl_stream << _cuda_warp_intrinsic_util;
+  }
+
   if (enable_int8_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
     decl_stream << "#include <sm_61_intrinsics.h>\n";
@@ -269,6 +273,11 @@ void CodeGenCUDA::PrintVecBinaryOp(
 
 void CodeGenCUDA::PrintVecElemLoad(
     const std::string& vec, DataType t, int i, std::ostream& os) {  // NOLINT(*)
+  if (t.is_scalar()) {
+    os << vec;
+    return;
+  }
+
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
@@ -395,7 +404,15 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
   os << sret;
 }
 
-void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
+void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
+  // This is only for backward compatibility with __shfl_{up/down}:
+  // a macro maps the *_sync calls back to the legacy ones on pre-Volta GPUs.
+  if (op->is_intrinsic("__shfl_sync") ||
+      op->is_intrinsic("__shfl_up_sync") ||
+      op->is_intrinsic("__shfl_down_sync")) {
+    enable_warp_shuffle_ = true;
+  }
+
   if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
     need_mma_h_ = true;
     CHECK_EQ(op->args.size(), 6U);
index d1db704..ed17638 100644 (file)
@@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC {
   bool enable_fp16_{false};
   // whether enable int8
   bool enable_int8_{false};
+  // whether enable warp shuffle intrinsics
+  bool enable_warp_shuffle_{false};
   // whether need math_constants.h
   bool need_math_constants_h_{false};
   // whether need mma.h
index f40dd5e..47425c3 100644 (file)
@@ -81,14 +81,34 @@ struct CUDAPopcount {
   }
 };
 
+
+struct CUDAWarpIntrinsic {
+  const char* operator()(DataType t, const std::string& name) const {
+    if (name == intrinsic::tvm_warp_shuffle) {
+      return "__shfl_sync";
+    }
+    if (name == intrinsic::tvm_warp_shuffle_up) {
+      return "__shfl_up_sync";
+    }
+    if (name == intrinsic::tvm_warp_shuffle_down) {
+      return "__shfl_down_sync";
+    }
+    if (name == intrinsic::tvm_warp_activemask) {
+      return "__activemask";
+    }
+    return "";
+  }
+};
+
+template <typename T>
 static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
   PrimExpr e = args[0];
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
-  CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+  CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
-  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
-  *rv = CallNode::make(
-      call->dtype, "__shfl", cuda_args, CallNode::PureExtern);
+  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
+  const char* name = T()(call->dtype, call->name);
+  *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern);
 }
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
@@ -158,7 +178,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
 .set_body(DispatchExtern<CUDAPopcount>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
-.set_body(DispatchCUDAShuffle);
+.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_up")
+.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down")
+.set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask")
+.set_body(DispatchExtern<CUDAWarpIntrinsic>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
 .set_body(DispatchExtern<CUDAMath>);
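
To summarize the dispatch above (an informal sketch, not code from this commit): each TIR warp intrinsic carries (mask, value, lane-or-offset, width, warp_size); the CUDA rule drops the trailing warp_size and renames the call via CUDAWarpIntrinsic.

# TIR intrinsic (mask, value, lane/offset, width, warp_size) -> CUDA extern call
lowering = {
    "tvm_warp_shuffle":      "__shfl_sync(mask, value, lane, width)",
    "tvm_warp_shuffle_up":   "__shfl_up_sync(mask, value, offset, width)",
    "tvm_warp_shuffle_down": "__shfl_down_sync(mask, value, offset, width)",
    "tvm_warp_activemask":   "__activemask()",
}
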
index 7374e6d..d7f63a6 100644 (file)
@@ -94,13 +94,13 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
   PrimExpr e = args[0];
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
-  CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+  CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
   arith::Analyzer analyzer;
-  CHECK(analyzer.CanProve(call->args[2] == call->args[3]))
+  CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
     << "Intel warp shuffle dose not support width != warp_size";
-  Array<PrimExpr> cuda_args{{call->args[0], call->args[1]}};
-  *rv = CallNode::make(
-      call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern);
+  Array<PrimExpr> opencl_args{{call->args[1], call->args[2]}};
+  *rv = CallNode::make(call->dtype, "intel_sub_group_shuffle",
+    opencl_args, CallNode::PureExtern);
 }
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
index 858ac85..baf4ba7 100644 (file)
@@ -295,4 +295,18 @@ __pack_half2(const half x, const half y) {
 }
 )";
 
+static constexpr const char* _cuda_warp_intrinsic_util = R"(
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
+#define __shfl_sync(mask, var, lane, width) \
+        __shfl((var), (lane), (width))
+
+#define __shfl_down_sync(mask, var, offset, width) \
+        __shfl_down((var), (offset), (width))
+
+#define __shfl_up_sync(mask, var, offset, width) \
+        __shfl_up((var), (offset), (width))
+#endif
+
+)";
+
 #endif  // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
index 9cb817d..11e420b 100644 (file)
@@ -39,8 +39,8 @@ namespace tir {
 
 class ThreadAllreduceBuilder final : public StmtExprMutator {
  public:
-  explicit ThreadAllreduceBuilder(int warp_size)
-      : warp_size_(warp_size) {}
+  explicit ThreadAllreduceBuilder(const TargetNode* target)
+    : target_(target), warp_size_(target->thread_warp_size) {}
 
   Stmt VisitStmt_(const AttrStmtNode *op) final {
     if (op->attr_key == attr::thread_extent) {
@@ -84,15 +84,22 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     auto it = alloc_remap_.find(op->buffer_var.get());
     if (it != alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
-      // use volatile access to shared buffer.
-      stmt = AttrStmtNode::make(
-          repl->buffer_var, attr::volatile_scope, 1, op->body);
-      stmt = AllocateNode::make(
-          repl->buffer_var, repl->dtype,
-          repl->extents, repl->condition, stmt);
-      stmt = AttrStmtNode::make(
-          repl->buffer_var, attr::storage_scope,
-          StringImmNode::make("shared"), stmt);
+      if (warp_allocs_.count(repl)) {
+        stmt = AllocateNode::make(repl->buffer_var, repl->dtype,
+            repl->extents, repl->condition, op->body);
+        stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope,
+            StringImmNode::make("local"), stmt);
+      } else {
+        // use volatile access to shared buffer.
+        stmt = AttrStmtNode::make(
+            repl->buffer_var, attr::volatile_scope, 1, op->body);
+        stmt = AllocateNode::make(
+            repl->buffer_var, repl->dtype,
+            repl->extents, repl->condition, stmt);
+        stmt = AttrStmtNode::make(
+            repl->buffer_var, attr::storage_scope,
+            StringImmNode::make("shared"), stmt);
+      }
       return stmt;
     } else {
       return stmt;
@@ -119,6 +126,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       return scope.dim_index < other.scope.dim_index;
     }
   };
+
   // make allreduce.
   Stmt MakeAllreduce(const CallNode* call) {
     CHECK(!reduce_combiner_.empty());
@@ -131,7 +139,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     Array<PrimExpr> inits = combiner->identity_element;
     std::vector<PrimExpr> values(size);
     std::vector<DataType> types(size);
-    PrimExpr cond  = call->args[size+1];
+    PrimExpr cond = call->args[size+1];
     for (size_t idx = 0; idx < size; ++idx) {
       values[idx] = call->args[1+idx];
       if (!is_one(cond)) {
@@ -181,52 +189,196 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     std::sort(vpar.begin(), vpar.end());
     // the size of each index.
     int reduce_extent, group_extent;
-    int threadx_extent = 1;
     PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
     PrimExpr group_index = FlattenThread(vpar, &group_extent);
-    if (reduce_extent == 1) {
-      // special case, no reduction is needed.
-      std::vector<Stmt> stores(size);
+    std::vector<Stmt> seq;
+    std::vector<Var> shared_bufs(size);
+    std::vector<Stmt> local_vars;
+    //
+    // This is an optimization. For small reduction sizes, it may be beneficial
+    // for a single warp to perform the entire reduction. No trips to shared
+    // memory and no cross-warp synchronization are required.
+    // The following code emits the reduction as follows:
+    //
+    // Allocate reduction vars v[i], i = 0..size-1
+    //
+    // for offset from 16 down to 1, halving each step
+    //
+    //   a    <- load(v[i])
+    //   b    <- shuffle_down(load(v[i]), offset)
+    //   v[i] <- reduction(a, b)
+    //
+    // broadcast the result from lane 0 to all other lanes and store
+    // the final reduction result to the proper location.
+    //
+    if (is_warp_reduction(types)) {
+      // TODO(tvm-team) sub-warp reduction support.
+      CHECK_EQ(reduce_extent, warp_size_) << "not a warp reduction";
+      //
+      // This is the index into the reduction variable, one reduction
+      // variable per warp. Local scope is easier to reason about without
+      // relying on a pattern-match pass to fix it up later.
+      PrimExpr index(0);
+
+      for (size_t idx = 0; idx < size; ++idx) {
+        shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
+        PrimExpr pred = const_true(types[idx].lanes());
+        seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], index, pred));
+
+        // Uses a local variable to store the shuffled data.
+        // Later on, this allocation will be properly attached to this statement.
+        Var var("t" + std::to_string(idx), types[idx]);
+        Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred,
+                                    EvaluateNode::make(0));
+        local_vars.push_back(s);
+      }
+
+      // The mask for this reducer, as this reducer may sit inside
+      // divergent control flow. A variable is used to cache the
+      // currently active lanes of the warp.
+      //
+      Var mask_var("mask", DataType::UInt(32));
+      {
+        PrimExpr pred = const_true(1);
+        PrimExpr mask = CallNode::make(DataType::UInt(32),
+          intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic);
+        seq.emplace_back(StoreNode::make(mask_var, mask, index, pred));
+        // Push allocation with an empty body. Later this will be fixed
+        // when the entire body is ready.
+        auto stmt = AllocateNode::make(mask_var, mask_var->dtype,
+          {PrimExpr(1)}, pred, EvaluateNode::make(0));
+        local_vars.push_back(stmt);
+      }
+
+      // Emit reductions within a warp.
+      for (int offset = 16; offset > 0; offset /= 2) {
+        // Load reduction values, no synchronization needed.
+        Array<PrimExpr> a, b;
+        for (size_t i = 0; i < size; ++i) {
+          Var var = shared_bufs[i];
+          PrimExpr pred = const_true(types[i].lanes());
+          PrimExpr val = LoadNode::make(types[i], var, index, pred);
+          a.push_back(val);
+
+          // __shfl_*sync calls shall not appear in if_then_else expressions
+          // as this causes extra divergence. E.g.
+          //
+          // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
+          //
+          // behaves differently from
+          //
+          // int t = __shfl_sync(mask, v1, 0);
+          // v1 = (v2 < v3) ? v3 : t;
+          //
+          // The former may cause a deadlock as there is a divergent
+          // branch with a warp sync call inside.
+          //
+          const char* shfl_func = intrinsic::tvm_warp_shuffle_down;
+          PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset);
+          const AllocateNode* repl = local_vars[i].as<AllocateNode>();
+          Stmt s = StoreNode::make(repl->buffer_var, other, index, pred);
+          seq.push_back(s);
+
+          PrimExpr load = LoadNode::make(types[i], repl->buffer_var, index, pred);
+          b.push_back(load);
+        }
+
+        // Do reductions.
+        Array<PrimExpr> ret = (*combiner)(a, b);
+
+        // Store the reduction result to itself.
+        std::vector<Stmt> stores(size);
+        for (size_t i = 0; i < size; ++i) {
+          Var var = shared_bufs[i];
+          PrimExpr pred = const_true(types[i].lanes());
+          stores[i] = StoreNode::make(var, ret[i], index, pred);
+        }
+        seq.push_back(SeqStmt::Flatten(stores));
+      }
+
+      // Broadcast the reduction result from lane 0 to all other lanes.
+      // This avoids emitting predicated stores, as all threads
+      // uniformly write the same result.
+      //
       for (size_t i = 0; i < size; ++i) {
+        Var var = shared_bufs[i];
         PrimExpr pred = const_true(types[i].lanes());
-        Var buffer_var = Downcast<Var>(call->args[2+size+i]);
-        stores[i] = StoreNode::make(buffer_var, values[i], 0, pred);
+        const char* shfl_func = intrinsic::tvm_warp_shuffle;
+        PrimExpr val = LoadNode::make(types[i], var, index, pred);
+        PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0);
+        seq.push_back(StoreNode::make(var, splat, index, pred));
+      }
+
+      // Update existing allocations.
+      for (size_t i = 0; i < size; ++i) {
+        CHECK(!load_remap_.count(buffers[i]));
+        PrimExpr pred = const_true(types[i].lanes());
+        Var var = shared_bufs[i];
+        load_remap_[buffers[i]] = LoadNode::make(types[i], var, index, pred);
+        Array<PrimExpr> extents{PrimExpr(1)};
+        auto node = AllocateNode::make(var, types[i], extents, pred,
+                                       EvaluateNode::make(0));
+        alloc_remap_[buffers[i]] = node;
+        warp_allocs_.insert(node.get());
+      }
+    } else {
+      int threadx_extent = 1;
+      if (reduce_extent == 1) {
+        // special case, no reduction is needed.
+        std::vector<Stmt> stores(size);
+        for (size_t i = 0; i < size; ++i) {
+          PrimExpr pred = const_true(types[i].lanes());
+          Var buffer_var = Downcast<Var>(call->args[2+size+i]);
+          stores[i] = StoreNode::make(buffer_var, values[i], 0, pred);
+        }
+        return SeqStmt::Flatten(stores);
+      }
+      // Whether the threadIdx.x is involved in reduction.
+      if (vred[0].scope.dim_index == 0) {
+        threadx_extent = vred[0].extent;
+      }
+      // This sync is necessary because there might be an incomplete read
+      // of the previous iteration on the same buffer.
+      seq.emplace_back(SyncThread("shared"));
+      for (size_t idx = 0; idx < size; ++idx) {
+        shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
+        PrimExpr pred = const_true(types[idx].lanes());
+        seq.emplace_back(StoreNode::make(
+            shared_bufs[idx], values[idx],
+            BufIndex(reduce_index, group_index, reduce_extent), pred));
+      }
+      seq.emplace_back(SyncThread("shared"));
+      seq.emplace_back(MakeBufAllreduce(
+          combiner, types, shared_bufs,
+          reduce_index, group_index, reduce_extent, threadx_extent));
+      for (size_t idx = 0; idx < size; ++idx) {
+        CHECK(!load_remap_.count(buffers[idx]));
+        PrimExpr pred = const_true(types[idx].lanes());
+        load_remap_[buffers[idx]] = LoadNode::make(
+          types[idx], shared_bufs[idx],
+          BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
+        alloc_remap_[buffers[idx]] = AllocateNode::make(
+          shared_bufs[idx], types[idx],
+          {PrimExpr(group_extent), PrimExpr(reduce_extent)},
+          pred, EvaluateNode::make(0));
       }
-      return SeqStmt::Flatten(stores);
-    }
-    // Whether the threadIdx.x is involved in reduction.
-    if (vred[0].scope.dim_index == 0) {
-      threadx_extent = vred[0].extent;
-    }
-    std::vector<Stmt> seq;
-    std::vector<Var> shared_bufs(size);
-    // This sync is necessary because there might be incomplete read of
-    // previous iteration on the same buffer.
-    seq.emplace_back(SyncThread("shared"));
-    for (size_t idx = 0; idx < size; ++idx) {
-      shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle());
-      PrimExpr pred = const_true(types[idx].lanes());
-      seq.emplace_back(StoreNode::make(
-          shared_bufs[idx], values[idx],
-          BufIndex(reduce_index, group_index, reduce_extent), pred));
     }
-    seq.emplace_back(SyncThread("shared"));
-    seq.emplace_back(MakeBufAllreduce(
-        combiner, types, shared_bufs,
-        reduce_index, group_index, reduce_extent, threadx_extent));
-    for (size_t idx = 0; idx < size; ++idx) {
-      CHECK(!load_remap_.count(buffers[idx]));
-      PrimExpr pred = const_true(types[idx].lanes());
-      load_remap_[buffers[idx]] = LoadNode::make(
-        types[idx], shared_bufs[idx],
-        BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred);
-      alloc_remap_[buffers[idx]] = AllocateNode::make(
-        shared_bufs[idx], types[idx],
-        {PrimExpr(group_extent), PrimExpr(reduce_extent)},
-        pred, EvaluateNode::make(0));
+
+    // Fix up all local allocations now that all statements have been built.
+    Stmt body = SeqStmt::Flatten(seq);
+    for (auto var : local_vars) {
+      const AllocateNode* repl = var.as<AllocateNode>();
+      if (repl) {
+        body = AllocateNode::make(repl->buffer_var, repl->dtype,
+            repl->extents, repl->condition, body);
+        body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope,
+            StringImmNode::make("local"), body);
+      }
     }
-    return SeqStmt::Flatten(seq);
+
+    return body;
   }
+
   // make allreduce.
   Stmt MakeBufAllreduce(const CommReducerNode *combiner,
                         const std::vector<DataType>& types,
@@ -330,6 +482,59 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
                    {StringImmNode::make(sync)},
                    CallNode::Intrinsic));
   }
+
+  // Emit warp shuffle intrinsic calls.
+  PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val,
+                       int delta_or_lane) {
+    PrimExpr pred = const_true(1);
+    PrimExpr index(0);
+    PrimExpr mask = LoadNode::make(DataType::UInt(32), mask_var, index, pred);
+    PrimExpr width = IntImm(DataType::Int(32), warp_size_);
+    Array<PrimExpr> args{mask, val, IntImm(DataType::Int(32), delta_or_lane),
+                         width, width};
+    return CallNode::make(val.dtype(), name, args, CallNode::Intrinsic);
+  }
+
+  // Check if this is a reduction on threadIdx.x and its extent matches
+  // the warp size.
+  //
+  // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads.
+  bool is_warp_reduction(const std::vector<DataType>& types) const {
+    // Only cuda target supports warp reductions.
+    if (target_->target_name != "cuda") return false;
+
+    // Supported types:
+    // {u}int, {u}long, {u}long long, float, double, half/half2
+    if (std::any_of(types.begin(), types.end(), [](DataType ty) {
+          if (ty.is_float16()) return ty.lanes() > 2;
+          if (ty.is_vector()) return true;
+          return ty.bytes() < 4 || ty.bytes() > 8;
+        })) {
+      return false;
+    }
+    if (thread_extents_.empty()) {
+      return false;
+    }
+
+    const AttrStmtNode* op = thread_extents_.back();
+    DCHECK_EQ(op->attr_key, attr::thread_extent);
+
+    IterVar iv = Downcast<IterVar>(op->node);
+    ThreadEntry e;
+    e.scope = runtime::ThreadScope::make(iv->thread_tag);
+    e.extent = 0;
+    if (auto ptr = op->value.as<IntImmNode>()) {
+      e.extent = static_cast<int>(ptr->value);
+    }
+
+    return e.extent == warp_size_ &&
+           e.scope.dim_index == 0 &&
+           e.scope.rank == 1;
+  }
+
+  // The target.
+  const TargetNode* target_ = nullptr;
+
   // The warp size of the device.
   int warp_size_{1};
 
@@ -340,6 +545,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   std::unordered_map<const VarNode *, PrimExpr> load_remap_;
   // Allocate remap
   std::unordered_map<const VarNode *, Stmt> alloc_remap_;
+  // Allocate from warp reductions
+  std::unordered_set<const void *> warp_allocs_;
   // Internal analyzer
   arith::Analyzer analyzer_;
 };
@@ -352,7 +559,8 @@ Pass LowerThreadAllreduce() {
     auto target = f->GetAttr<Target>(tvm::attr::kTarget);
     CHECK(target.defined())
         << "LowerThreadAllreduce: Require the target attribute";
-    n->body = ThreadAllreduceBuilder(target.value()->thread_warp_size)(n->body);
+    const TargetNode* target_node = target.as<TargetNode>();
+    n->body = ThreadAllreduceBuilder(target_node)(n->body);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
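
The shuffle-down loop that ThreadAllreduceBuilder emits for a warp-sized reduction can be modeled in a few lines of NumPy (a conceptual sketch only, assuming 32 lanes and width == warp size; names are illustrative): after the halving-offset loop lane 0 holds the full reduction, and the final tvm_warp_shuffle broadcasts it to every lane.

import numpy as np

def warp_allreduce(values, combine=np.add):
    # values[i] models the per-lane register red_buf; 32 lanes, full warp.
    v = np.asarray(values, dtype=np.float64).copy()
    offset = 16
    while offset > 0:
        # __shfl_down_sync model: lanes whose source is out of range keep their own value.
        shifted = np.concatenate([v[offset:], v[-offset:]])
        # Lanes >= 32 - offset combine with their own value; they are never read afterwards.
        v = combine(v, shifted)
        offset //= 2
    # Model of the final __shfl_sync from lane 0: broadcast the result to every lane.
    return np.full_like(v, v[0])

print(warp_allreduce(np.arange(32)))                  # 32 copies of 496.0
print(warp_allreduce(np.arange(1, 33), np.maximum))   # 32 copies of 32.0
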
index 516b96c..0abbe76 100644 (file)
@@ -265,10 +265,12 @@ class WarpAccessRewriter : protected StmtExprMutator {
           << op->index << " local_index=" << local_index;
       PrimExpr load_value = LoadNode::make(
           op->dtype, op->buffer_var, local_index, op->predicate);
+      PrimExpr mask = CallNode::make(DataType::UInt(32),
+          intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic);
       return CallNode::make(load_value.dtype(),
-                        intrinsic::tvm_warp_shuffle,
-                        {load_value, group, width_, warp_size_},
-                        CallNode::Intrinsic);
+                            intrinsic::tvm_warp_shuffle,
+                            {mask, load_value, group, width_, warp_size_},
+                            CallNode::Intrinsic);
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
index 82ade44..7ac3496 100644 (file)
@@ -338,6 +338,106 @@ def test_rfactor_argmax():
     check_target("cuda")
     check_target("vulkan")
 
+def test_warp_reduction1():
+    nthx = 32
+    nthy = 4
+    block_x = te.thread_axis("blockIdx.x")
+    thread_x = te.thread_axis((0, nthx), "threadIdx.x")
+    thread_y = te.thread_axis((0, nthy), "threadIdx.y")
+
+    def check_target(device, m, n):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("skip because %s is not enabled.." % device)
+            return
+
+        # compute
+        A = te.placeholder((m, n), name='A')
+        k = te.reduce_axis((0, n))
+        B = te.compute((m,), lambda i: te.max(A[i][k], axis=k), name='B')
+        s = te.create_schedule(B.op)
+
+        # schedule
+        k = s[B].op.reduce_axis[0]
+        ko, _ = s[B].split(k, nparts=nthx)
+        s[B].bind(ko, thread_x)
+        xo, xi = s[B].split(s[B].op.axis[0], factor=nthy)
+        s[B].bind(xi, thread_y)
+        s[B].bind(xo, block_x)
+
+        print(tvm.lower(s, [A, B], simple_mode=True))
+
+        # validation
+        func = tvm.build(s, [A, B], "cuda", name="warp_reduction")
+        a_np = np.random.uniform(size=(m,n)).astype(A.dtype)
+        b_np = np.zeros((m,), dtype=A.dtype)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        b_np = np.max(a_np, axis=1)
+        func(a, b)
+        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)
+
+    check_target("cuda", m=32, n=256)
+    check_target("cuda", m=10, n=20)
+    # Disabled: m=10, n=37 hits a bug in the normal (non-warp) reduction path.
+    # check_target("cuda", m=10, n=37)
+
+def test_warp_reduction2():
+    def fcombine(x, y):
+        return x[0] + y[0], x[1] * y[1]
+
+    def fidentity(t0, t1):
+        return tvm.tir.const(0, t0), tvm.tir.const(1, t1)
+
+    add_mul_reducer = te.comm_reducer(fcombine, fidentity, name='add_mul_reducer')
+
+    # compute
+    m = 16
+    n = 256
+    A0 = te.placeholder((m, n), name='A0', dtype='float32')
+    A1 = te.placeholder((m, n), name='A1', dtype='float32')
+    k = te.reduce_axis((0, n), 'k')
+    T0, T1 = te.compute((m, ), lambda i: \
+        add_mul_reducer((A0[i, k], A1[i, k]), axis=k), name='T')
+
+    nthdx, nthdy = 32, 2
+    block_x = te.thread_axis("blockIdx.x")
+    thread_x = te.thread_axis((0, nthdx), "threadIdx.x")
+    thread_y = te.thread_axis((0, nthdy), "threadIdx.y")
+
+    def check_target(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("skip because %s is not enabled.." % device)
+            return
+
+        # schedule
+        s = te.create_schedule(T0.op)
+        ko, _ = s[T0].split(k, nparts=nthdx)
+        xo, xi = s[T0].split(s[T0].op.axis[0], factor=nthdy)
+        s[T0].bind(ko, thread_x)
+        s[T0].bind(xi, thread_y)
+        s[T0].bind(xo, block_x)
+
+        # validation
+        ctx = tvm.context(device, 0)
+        a0_np = np.random.uniform(size=(m,n)).astype(A0.dtype)
+        a1_np = np.random.uniform(size=(m,n)).astype(A1.dtype)
+        t0_np = np.zeros((m,), dtype=A0.dtype)
+        t1_np = np.zeros((m,), dtype=A1.dtype)
+        a0 = tvm.nd.array(a0_np, ctx)
+        a1 = tvm.nd.array(a1_np, ctx)
+        t0 = tvm.nd.array(t0_np, ctx)
+        t1 = tvm.nd.array(t1_np, ctx)
+        func = tvm.build(s, [A0, A1, T0, T1], device, name="reduction")
+        func(a0, a1, t0, t1)
+        t0_np = np.sum(a0_np, axis=1)
+        t1_np = np.product(a1_np, axis=1)
+        tvm.testing.assert_allclose(t0.asnumpy(), t0_np, rtol=1e-3, atol=1e-3)
+        tvm.testing.assert_allclose(t1.asnumpy(), t1_np, rtol=1e-3, atol=1e-3)
+
+    check_target("cuda")
+
 if __name__ == "__main__":
     test_rfactor_elemwise_threads()
     test_rfactor_threads()
@@ -346,3 +446,5 @@ if __name__ == "__main__":
     test_reduce_prims()
     test_argmax()
     test_rfactor_argmax()
+    test_warp_reduction1()
+    test_warp_reduction2()
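
Since the two new tests are added to the __main__ block above, they can be run directly against a CUDA-enabled build and device; the invocation below is a suggestion, not part of the commit.

python tests/python/integration/test_reduce.py
# or select just the new tests with pytest:
python -m pytest tests/python/integration/test_reduce.py -k warp_reduction
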