if (!dstType)
return failure();
+ // Get the scalar float type.
+ FloatType scalarFloatType;
+ if (auto scalarType = powfOp.getType().dyn_cast<FloatType>()) {
+ scalarFloatType = scalarType;
+ } else if (auto vectorType = powfOp.getType().dyn_cast<VectorType>()) {
+ scalarFloatType = vectorType.getElementType().cast<FloatType>();
+ } else {
+ return failure();
+ }
+
+ // Get int type of the same shape as the float type.
+ Type scalarIntType = rewriter.getIntegerType(32);
+ Type intType = scalarIntType;
+ if (auto vectorType = adaptor.getRhs().getType().dyn_cast<VectorType>()) {
+ auto shape = vectorType.getShape();
+ intType = VectorType::get(shape, scalarIntType);
+ }
+
// Per GL Pow extended instruction spec:
// "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
Location loc = powfOp.getLoc();
Value lessThan =
rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
+
+ // TODO: The following forcibly converts y to an integer in order to
+ // propagate the sign correctly, which is only valid when y holds an
+ // integer value. Other cases are not covered and should be fixed.
+
+ // Convert the exponent to an integer and check whether it is odd, i.e.
+ // (exponent & 1) == 1.
+ Value intRhs =
+ rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
+ Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
+ Value bitwiseAndOne =
+ rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
+ Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
+
+ // Calculate pow as abs(lhs)^rhs.
Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow);
+ // If the exponent is odd and lhs < 0, negate the result.
+ Value shouldNegate =
+ rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
+ pow);
return success();
}
};
// CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32
// CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32
// CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32
+ // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS
+ // CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
+ // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]]
+ // CHECK: %[[ODD:.+]] = spirv.IEqual %[[REM]], %[[CST1]] : i32
// CHECK: %[[POW:.+]] = spirv.GL.Pow %[[ABS]], %[[RHS]] : f32
// CHECK: %[[NEG:.+]] = spirv.FNegate %[[POW]] : f32
- // CHECK: %[[SEL:.+]] = spirv.Select %[[LT]], %[[NEG]], %[[POW]] : i1, f32
+ // CHECK: %[[SNEG:.+]] = spirv.LogicalAnd %[[LT]], %[[ODD]] : i1
+ // CHECK: %[[SEL:.+]] = spirv.Select %[[SNEG]], %[[NEG]], %[[POW]] : i1, f32
%0 = math.powf %lhs, %rhs : f32
// CHECK: return %[[SEL]]
return %0: f32
func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
// CHECK: spirv.FOrdLessThan
// CHECK: spirv.GL.FAbs
+ // CHECK: spirv.BitwiseAnd %{{.*}} : vector<4xi32>
+ // CHECK: spirv.IEqual %{{.*}} : vector<4xi32>
// CHECK: spirv.GL.Pow %{{.*}}: vector<4xf32>
// CHECK: spirv.FNegate
// CHECK: spirv.Select