[mlir][spirv] Workaround driver bug in math.ctlz conversion again
authorLei Zhang <antiagainst@google.com>
Thu, 16 Jun 2022 14:48:33 +0000 (10:48 -0400)
committerLei Zhang <antiagainst@google.com>
Thu, 16 Jun 2022 14:53:49 +0000 (10:53 -0400)
The previous approach does not work as the Adreno driver is
clever at optimizing away the selection. So now check two
inputs together.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D127930

mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir

index 07c99f06ab2c313c23cf103953ed1fc2d4081c0b..ea367b1faa201e959057226d249711c864bbff07 100644 (file)
@@ -141,20 +141,25 @@ class CountLeadingZerosPattern final
       return failure();
 
     Location loc = countOp.getLoc();
-    Value allOneBits = getScalarOrVectorI32Constant(type, -1, rewriter, loc);
-    Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
+    Value input = adaptor.getOperand();
+    Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
     Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
-    Value msb =
-        rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
-    // We need to subtract from 31 given that the index is from the least
-    // significant bit.
-    Value sub = rewriter.create<spirv::ISubOp>(loc, val31, msb);
-    // If the integer has all zero bits, GLSL FindUMsb would return -1. So
-    // theoretically (31 - FindUMsb) should still give the correct result.
-    // However, certain Vulkan implementations have driver bugs regarding it.
-    // So handle the corner case explicity to workaround it.
-    Value cmp = rewriter.create<spirv::IEqualOp>(loc, msb, allOneBits);
-    rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, val32, sub);
+    Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
+
+    Value msb = rewriter.create<spirv::GLSLFindUMsbOp>(loc, input);
+    // We need to subtract from 31 given that the index returned by GLSL
+    // FindUMsb is counted from the least significant bit. Theoretically this
+    // also gives the correct result even if the integer has all zero bits, in
+    // which case GLSL FindUMsb would return -1.
+    Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
+    // However, certain Vulkan implementations have driver bugs for the corner
+    // case where the input is zero. And.. it can be smart to optimize a select
+    // only involving the corner case. So separately compute the result when the
+    // input is either zero or one.
+    Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
+    Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
+                                                 subMsb);
     return success();
   }
 };
index d8126d4e956c6207b26f503b294f0317a492e86f..a3067af661f64270a34c697ed7e8224e92995101 100644 (file)
@@ -82,13 +82,14 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
 // CHECK-LABEL: @ctlz_scalar
 //  CHECK-SAME: (%[[VAL:.+]]: i32)
 func.func @ctlz_scalar(%val: i32) -> i32 {
-  // CHECK-DAG: %[[MAX:.+]] = spv.Constant -1 : i32
-  // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32
+  // CHECK-DAG: %[[V1:.+]] = spv.Constant 1 : i32
   // CHECK-DAG: %[[V31:.+]] = spv.Constant 31 : i32
+  // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32
   // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
-  // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
-  // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : i32
-  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : i1, i32
+  // CHECK: %[[SUB1:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
+  // CHECK: %[[SUB2:.+]] = spv.ISub %[[V32]], %[[VAL]] : i32
+  // CHECK: %[[CMP:.+]] = spv.ULessThanEqual %[[VAL]], %[[V1]] : i32
+  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[SUB2]], %[[SUB1]] : i1, i32
   // CHECK: return %[[R]]
   %0 = math.ctlz %val : i32
   return %0 : i32
@@ -98,7 +99,7 @@ func.func @ctlz_scalar(%val: i32) -> i32 {
 func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
   // CHECK: spv.GLSL.FindUMsb
   // CHECK: spv.ISub
-  // CHECK: spv.IEqual
+  // CHECK: spv.ULessThanEqual
   // CHECK: spv.Select
   %0 = math.ctlz %val : vector<1xi32>
   return %0 : vector<1xi32>
@@ -107,14 +108,14 @@ func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
 // CHECK-LABEL: @ctlz_vector2
 //  CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
 func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
-  // CHECK-DAG: %[[MAX:.+]] = spv.Constant dense<-1> : vector<2xi32>
-  // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32>
+  // CHECK-DAG: %[[V1:.+]] = spv.Constant dense<1> : vector<2xi32>
   // CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
+  // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32>
   // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
-  // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
-  // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : vector<2xi32>
-  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : vector<2xi1>, vector<2xi32>
-  // CHECK: return %[[R]]
+  // CHECK: %[[SUB1:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
+  // CHECK: %[[SUB2:.+]] = spv.ISub %[[V32]], %[[VAL]] : vector<2xi32>
+  // CHECK: %[[CMP:.+]] = spv.ULessThanEqual %[[VAL]], %[[V1]] : vector<2xi32>
+  // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[SUB2]], %[[SUB1]] : vector<2xi1>, vector<2xi32>
   %0 = math.ctlz %val : vector<2xi32>
   return %0 : vector<2xi32>
 }