return failure();
Location loc = countOp.getLoc();
- Value allOneBits = getScalarOrVectorI32Constant(type, -1, rewriter, loc);
- Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
+ Value input = adaptor.getOperand();
+ Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
- Value msb =
- rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
- // We need to subtract from 31 given that the index is from the least
- // significant bit.
- Value sub = rewriter.create<spirv::ISubOp>(loc, val31, msb);
- // If the integer has all zero bits, GLSL FindUMsb would return -1. So
- // theoretically (31 - FindUMsb) should still give the correct result.
- // However, certain Vulkan implementations have driver bugs regarding it.
- // So handle the corner case explicity to workaround it.
- Value cmp = rewriter.create<spirv::IEqualOp>(loc, msb, allOneBits);
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, val32, sub);
+ Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
+
+ Value msb = rewriter.create<spirv::GLSLFindUMsbOp>(loc, input);
+ // We need to subtract from 31 given that the index returned by GLSL
+ // FindUMsb is counted from the least significant bit. Theoretically this
+ // also gives the correct result even if the integer has all zero bits, in
+ // which case GLSL FindUMsb would return -1.
+ Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
+ // However, certain Vulkan implementations have driver bugs for the corner
+ // case where the input is zero. And.. it can be smart to optimize a select
+ // only involving the corner case. So separately compute the result when the
+ // input is either zero or one.
+ Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
+ Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
+ subMsb);
return success();
}
};
// CHECK-LABEL: @ctlz_scalar
// CHECK-SAME: (%[[VAL:.+]]: i32)
func.func @ctlz_scalar(%val: i32) -> i32 {
- // CHECK-DAG: %[[MAX:.+]] = spv.Constant -1 : i32
- // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32
+ // CHECK-DAG: %[[V1:.+]] = spv.Constant 1 : i32
// CHECK-DAG: %[[V31:.+]] = spv.Constant 31 : i32
+ // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32
// CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
- // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
- // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : i32
- // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : i1, i32
+ // CHECK: %[[SUB1:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
+ // CHECK: %[[SUB2:.+]] = spv.ISub %[[V32]], %[[VAL]] : i32
+ // CHECK: %[[CMP:.+]] = spv.ULessThanEqual %[[VAL]], %[[V1]] : i32
+ // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[SUB2]], %[[SUB1]] : i1, i32
// CHECK: return %[[R]]
%0 = math.ctlz %val : i32
return %0 : i32
func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
// CHECK: spv.GLSL.FindUMsb
// CHECK: spv.ISub
- // CHECK: spv.IEqual
+ // CHECK: spv.ULessThanEqual
// CHECK: spv.Select
%0 = math.ctlz %val : vector<1xi32>
return %0 : vector<1xi32>
// CHECK-LABEL: @ctlz_vector2
// CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
- // CHECK-DAG: %[[MAX:.+]] = spv.Constant dense<-1> : vector<2xi32>
- // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32>
+ // CHECK-DAG: %[[V1:.+]] = spv.Constant dense<1> : vector<2xi32>
// CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
+ // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32>
// CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
- // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
- // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : vector<2xi32>
- // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : vector<2xi1>, vector<2xi32>
- // CHECK: return %[[R]]
+ // CHECK: %[[SUB1:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
+ // CHECK: %[[SUB2:.+]] = spv.ISub %[[V32]], %[[VAL]] : vector<2xi32>
+ // CHECK: %[[CMP:.+]] = spv.ULessThanEqual %[[VAL]], %[[V1]] : vector<2xi32>
+ // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[SUB2]], %[[SUB1]] : vector<2xi1>, vector<2xi32>
%0 = math.ctlz %val : vector<2xi32>
return %0 : vector<2xi32>
}