[spirv][math] Fix sign propagation for math.powf conversion
authorDaniel Garvey <dan@nod-labs.com>
Wed, 10 May 2023 04:26:37 +0000 (21:26 -0700)
committerLei Zhang <antiagainst@google.com>
Wed, 10 May 2023 04:44:09 +0000 (21:44 -0700)
For `x^y`, the result's sign should consider whether `y` is
an integer and whether it's odd or even.

This still does not cover all corner cases regarding `x^y`
but it's an improvement over the current implementation.

Reviewed By: antiagainst, qedawkins

Differential Revision: https://reviews.llvm.org/D150234

mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir

index 80b2257..412f99c 100644 (file)
@@ -305,6 +305,24 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     if (!dstType)
       return failure();
 
+    // Get the scalar float type.
+    FloatType scalarFloatType;
+    if (auto scalarType = powfOp.getType().dyn_cast<FloatType>()) {
+      scalarFloatType = scalarType;
+    } else if (auto vectorType = powfOp.getType().dyn_cast<VectorType>()) {
+      scalarFloatType = vectorType.getElementType().cast<FloatType>();
+    } else {
+      return failure();
+    }
+
+    // Get int type of the same shape as the float type.
+    Type scalarIntType = rewriter.getIntegerType(32);
+    Type intType = scalarIntType;
+    if (auto vectorType = adaptor.getRhs().getType().dyn_cast<VectorType>()) {
+      auto shape = vectorType.getShape();
+      intType = VectorType::get(shape, scalarIntType);
+    }
+
     // Per GL Pow extended instruction spec:
     // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
     Location loc = powfOp.getLoc();
@@ -313,9 +331,27 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     Value lessThan =
         rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
     Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
+
+    // TODO: The following just forcefully casts y into an integer value in
+    // order to properly propagate the sign, assuming integer y cases. It
+    // doesn't cover other cases and should be fixed.
+
+    // Cast exponent to integer and calculate exponent % 2 != 0.
+    Value intRhs =
+        rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
+    Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
+    Value bitwiseAndOne =
+        rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
+    Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
+
+    // calculate pow based on abs(lhs)^rhs.
     Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
     Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
-    rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow);
+    // if the exponent is odd and lhs < 0, negate the result.
+    Value shouldNegate =
+        rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
+                                                 pow);
     return success();
   }
 };
index 125478e..4d0ef06 100644 (file)
@@ -137,9 +137,14 @@ func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
   // CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32
   // CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32
   // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32
+  // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS
+  // CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
+  // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]]
+  // CHECK: %[[ODD:.+]] = spirv.IEqual %[[REM]], %[[CST1]] : i32
   // CHECK: %[[POW:.+]] = spirv.GL.Pow %[[ABS]], %[[RHS]] : f32
   // CHECK: %[[NEG:.+]] = spirv.FNegate %[[POW]] : f32
-  // CHECK: %[[SEL:.+]] = spirv.Select %[[LT]], %[[NEG]], %[[POW]] : i1, f32
+  // CHECK: %[[SNEG:.+]] = spirv.LogicalAnd %[[LT]], %[[ODD]] : i1
+  // CHECK: %[[SEL:.+]] = spirv.Select %[[SNEG]], %[[NEG]], %[[POW]] : i1, f32
   %0 = math.powf %lhs, %rhs : f32
   // CHECK: return %[[SEL]]
   return %0: f32
@@ -149,6 +154,8 @@ func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
 func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
   // CHECK: spirv.FOrdLessThan
   // CHECK: spirv.GL.FAbs
+  // CHECK: spirv.BitwiseAnd %{{.*}} : vector<4xi32>
+  // CHECK: spirv.IEqual %{{.*}} : vector<4xi32>
   // CHECK: spirv.GL.Pow %{{.*}}: vector<4xf32>
   // CHECK: spirv.FNegate
   // CHECK: spirv.Select