Switch from shfl.bfly to shfl.down.
authorChristian Sigg <csigg@google.com>
Thu, 12 Dec 2019 09:27:27 +0000 (01:27 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 12 Dec 2019 09:28:01 +0000 (01:28 -0800)
Both work for the current use case, but the latter allows implementing
prefix sums and is a little easier to understand for partial warps.

PiperOrigin-RevId: 285145287

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/nvvmir.mlir

index c6e89af..9a0e43e 100644 (file)
@@ -90,8 +90,8 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
   let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
 }
 
-def NVVM_ShflBflyOp :
-  NVVM_Op<"shfl.sync.bfly">,
+def NVVM_ShflDownOp :
+  NVVM_Op<"shfl.sync.down">,
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins LLVM_Type:$dst,
                  LLVM_Type:$val,
@@ -99,12 +99,12 @@ def NVVM_ShflBflyOp :
                  LLVM_Type:$mask_and_clamp,
                  OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
   string llvmBuilder = [{
-      auto intId = getShflBflyIntrinsicId(
+      auto intId = getShflDownIntrinsicId(
           $_resultType, static_cast<bool>($return_value_and_is_valid));
       $res = createIntrinsicCall(builder,
           intId, {$dst, $val, $offset, $mask_and_clamp});
   }];
-  let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
+  let parser = [{ return parseNVVMShflSyncDownOp(parser, result); }];
   let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
   let verifier = [{
     if (!getAttrOfType<UnitAttr>("return_value_and_is_valid"))
index ccad2cd..a7474cf 100644 (file)
@@ -337,7 +337,7 @@ private:
           for (int i = 1; i < kWarpSize; i <<= 1) {
             Value *offset = rewriter.create<LLVM::ConstantOp>(
                 loc, int32Type, rewriter.getI32IntegerAttr(i));
-            Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
+            Value *shfl = rewriter.create<NVVM::ShflDownOp>(
                 loc, shflTy, activeMask, value, offset, maskAndClamp,
                 returnValueAndIsValidAttr);
             Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
@@ -366,7 +366,7 @@ private:
           for (int i = 1; i < kWarpSize; i <<= 1) {
             Value *offset = rewriter.create<LLVM::ConstantOp>(
                 loc, int32Type, rewriter.getI32IntegerAttr(i));
-            Value *shflValue = rewriter.create<NVVM::ShflBflyOp>(
+            Value *shflValue = rewriter.create<NVVM::ShflDownOp>(
                 loc, type, activeMask, value, offset, maskAndClamp,
                 /*return_value_and_is_valid=*/UnitAttr());
             value = accumFactory(loc, value, shflValue, rewriter);
index 0b10391..32bb3a6 100644 (file)
@@ -70,9 +70,9 @@ static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
 }
 
 // <operation> ::=
-//     `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
+//     `llvm.nvvm.shfl.sync.down %dst, %val, %offset, %clamp_and_mask`
 //      ({return_value_and_is_valid})? : result_type
-static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
+static ParseResult parseNVVMShflSyncDownOp(OpAsmParser &parser,
                                            OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 8> ops;
   Type resultType;
index 606e91b..cdaa384 100644 (file)
@@ -44,15 +44,15 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
   return builder.CreateCall(fn, args);
 }
 
-static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
+static llvm::Intrinsic::ID getShflDownIntrinsicId(llvm::Type *resultType,
                                                   bool withPredicate) {
   if (withPredicate) {
     resultType = cast<llvm::StructType>(resultType)->getElementType(0);
-    return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
-                                   : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
+    return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
+                                   : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
   }
-  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
-                                 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+  return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
+                                 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
 }
 
 class ModuleTranslation : public LLVM::ModuleTranslation {
index 30bba48..a68c775 100644 (file)
@@ -44,7 +44,7 @@ module attributes {gpu.kernel_module} {
       attributes { gpu.kernel } {
     %arg0 = constant 1.0 : f32
     // TODO(csigg): Check full IR expansion once lowering has settled.
-    // CHECK: nvvm.shfl.sync.bfly
+    // CHECK: nvvm.shfl.sync.down
     // CHECK: nvvm.barrier0
     // CHECK: llvm.fadd
     %result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
@@ -61,7 +61,7 @@ module attributes {gpu.kernel_module} {
       attributes { gpu.kernel } {
     %arg0 = constant 1 : i32
     // TODO(csigg): Check full IR expansion once lowering has settled.
-    // CHECK: nvvm.shfl.sync.bfly
+    // CHECK: nvvm.shfl.sync.down
     // CHECK: nvvm.barrier0
     %result = "gpu.all_reduce"(%arg0) ({
     ^bb(%lhs : i32, %rhs : i32):
index c7d4ff4..d6df074 100644 (file)
@@ -268,7 +268,7 @@ func @null_non_llvm_type() {
 // CHECK-LABEL: @nvvm_invalid_shfl_pred_1
 func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
   // expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
+  %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
 }
 
 // -----
@@ -276,7 +276,7 @@ func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !ll
 // CHECK-LABEL: @nvvm_invalid_shfl_pred_2
 func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
   // expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
+  %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
 }
 
 // -----
@@ -284,7 +284,7 @@ func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !ll
 // CHECK-LABEL: @nvvm_invalid_shfl_pred_3
 func @nvvm_invalid_shfl_pred_3(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
   // expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
+  %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
 }
 
 // -----
index 3858ace..50a7f7d 100644 (file)
@@ -37,20 +37,20 @@ func @llvm.nvvm.barrier0() {
 func @nvvm_shfl(
     %arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
     %arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm.i32 {
-  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : !llvm.i32
-  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
-  %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : !llvm.float
+  // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 : !llvm.i32
+  // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
+  %1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 : !llvm.float
   llvm.return %0 : !llvm.i32
 }
 
 func @nvvm_shfl_pred(
     %arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
     %arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
-  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
-  %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
-  // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
-  %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
+  // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
+  %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
+  // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
+  %1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
   llvm.return %0 : !llvm<"{ i32, i1 }">
 }
 
index 2e63ecd..e97603a 100644 (file)
@@ -41,20 +41,20 @@ llvm.func @llvm.nvvm.barrier0() {
 llvm.func @nvvm_shfl(
     %0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
     %3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
-  // CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : !llvm.i32
-  // CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : !llvm.float
+  // CHECK: call i32 @llvm.nvvm.shfl.sync.down.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %6 = nvvm.shfl.sync.down %0, %3, %1, %2 : !llvm.i32
+  // CHECK: call float @llvm.nvvm.shfl.sync.down.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %7 = nvvm.shfl.sync.down %0, %4, %1, %2 : !llvm.float
   llvm.return %6 : !llvm.i32
 }
 
 llvm.func @nvvm_shfl_pred(
     %0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
     %3 : !llvm.i32, %4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
-  // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
-  // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
-  %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
+  // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %6 = nvvm.shfl.sync.down %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
+  // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %7 = nvvm.shfl.sync.down %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
   llvm.return %6 : !llvm<"{ i32, i1 }">
 }