let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
}
-def NVVM_ShflBflyOp :
- NVVM_Op<"shfl.sync.bfly">,
+def NVVM_ShflDownOp :
+ NVVM_Op<"shfl.sync.down">,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$dst,
LLVM_Type:$val,
LLVM_Type:$offset,
LLVM_Type:$mask_and_clamp,
OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
string llvmBuilder = [{
- auto intId = getShflBflyIntrinsicId(
+ auto intId = getShflDownIntrinsicId(
$_resultType, static_cast<bool>($return_value_and_is_valid));
$res = createIntrinsicCall(builder,
intId, {$dst, $val, $offset, $mask_and_clamp});
}];
- let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
+ let parser = [{ return parseNVVMShflSyncDownOp(parser, result); }];
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
let verifier = [{
if (!getAttrOfType<UnitAttr>("return_value_and_is_valid"))
for (int i = 1; i < kWarpSize; i <<= 1) {
Value *offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
- Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
+ Value *shfl = rewriter.create<NVVM::ShflDownOp>(
loc, shflTy, activeMask, value, offset, maskAndClamp,
returnValueAndIsValidAttr);
Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
for (int i = 1; i < kWarpSize; i <<= 1) {
Value *offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
- Value *shflValue = rewriter.create<NVVM::ShflBflyOp>(
+ Value *shflValue = rewriter.create<NVVM::ShflDownOp>(
loc, type, activeMask, value, offset, maskAndClamp,
/*return_value_and_is_valid=*/UnitAttr());
value = accumFactory(loc, value, shflValue, rewriter);
}
// <operation> ::=
-// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
+// `llvm.nvvm.shfl.sync.down %dst, %val, %offset, %mask_and_clamp`
// ({return_value_and_is_valid})? : result_type
-static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
+static ParseResult parseNVVMShflSyncDownOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> ops;
Type resultType;
return builder.CreateCall(fn, args);
}
-static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
+static llvm::Intrinsic::ID getShflDownIntrinsicId(llvm::Type *resultType,
bool withPredicate) {
if (withPredicate) {
resultType = cast<llvm::StructType>(resultType)->getElementType(0);
- return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
- : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
+ return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
+ : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
}
- return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
- : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
+ return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
+ : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
}
class ModuleTranslation : public LLVM::ModuleTranslation {
attributes { gpu.kernel } {
%arg0 = constant 1.0 : f32
// TODO(csigg): Check full IR expansion once lowering has settled.
- // CHECK: nvvm.shfl.sync.bfly
+ // CHECK: nvvm.shfl.sync.down
// CHECK: nvvm.barrier0
// CHECK: llvm.fadd
%result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
attributes { gpu.kernel } {
%arg0 = constant 1 : i32
// TODO(csigg): Check full IR expansion once lowering has settled.
- // CHECK: nvvm.shfl.sync.bfly
+ // CHECK: nvvm.shfl.sync.down
// CHECK: nvvm.barrier0
%result = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : i32, %rhs : i32):
// CHECK-LABEL: @nvvm_invalid_shfl_pred_1
func @nvvm_invalid_shfl_pred_1(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
+ %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.i32
}
// -----
// CHECK-LABEL: @nvvm_invalid_shfl_pred_2
func @nvvm_invalid_shfl_pred_2(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
+ %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32 }">
}
// -----
// CHECK-LABEL: @nvvm_invalid_shfl_pred_3
func @nvvm_invalid_shfl_pred_3(%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32, %arg3 : !llvm.i32) {
// expected-error@+1 {{expected return type !llvm<"{ ?, i1 }">}}
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
+ %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i32 }">
}
// -----
func @nvvm_shfl(
%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
%arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm.i32 {
- // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : !llvm.i32
- // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
- %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : !llvm.float
+ // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i32
+ %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 : !llvm.i32
+ // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.float
+ %1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 : !llvm.float
llvm.return %0 : !llvm.i32
}
func @nvvm_shfl_pred(
%arg0 : !llvm.i32, %arg1 : !llvm.i32, %arg2 : !llvm.i32,
%arg3 : !llvm.i32, %arg4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
- // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
- %0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
- // CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
- %1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
+ // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ i32, i1 }">
+ %0 = nvvm.shfl.sync.down %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
+ // CHECK: nvvm.shfl.sync.down %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm<"{ float, i1 }">
+ %1 = nvvm.shfl.sync.down %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
llvm.return %0 : !llvm<"{ i32, i1 }">
}
llvm.func @nvvm_shfl(
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm.i32 {
- // CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : !llvm.i32
- // CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : !llvm.float
+ // CHECK: call i32 @llvm.nvvm.shfl.sync.down.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %6 = nvvm.shfl.sync.down %0, %3, %1, %2 : !llvm.i32
+ // CHECK: call float @llvm.nvvm.shfl.sync.down.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %7 = nvvm.shfl.sync.down %0, %4, %1, %2 : !llvm.float
llvm.return %6 : !llvm.i32
}
llvm.func @nvvm_shfl_pred(
%0 : !llvm.i32, %1 : !llvm.i32, %2 : !llvm.i32,
%3 : !llvm.i32, %4 : !llvm.float) -> !llvm<"{ i32, i1 }"> {
- // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
- // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
- %7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
+ // CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %6 = nvvm.shfl.sync.down %0, %3, %1, %2 {return_value_and_is_valid} : !llvm<"{ i32, i1 }">
+ // CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %7 = nvvm.shfl.sync.down %0, %4, %1, %2 {return_value_and_is_valid} : !llvm<"{ float, i1 }">
llvm.return %6 : !llvm<"{ i32, i1 }">
}