ConversionPatternRewriter &rewriter) const override;
};
-/// Converts arith.remsi to SPIR-V ops.
+/// Converts arith.remsi to GLSL SPIR-V ops.
///
/// This cannot be merged into the template unary/binary pattern due to Vulkan
/// restrictions over spv.SRem and spv.SMod.
-struct RemSIOpPattern final : public OpConversionPattern<arith::RemSIOp> {
+struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
+ using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.remsi to OpenCL SPIR-V ops.
+struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
LogicalResult
}
//===----------------------------------------------------------------------===//
-// RemSIOpPattern
+// RemSIOpGLSLPattern
//===----------------------------------------------------------------------===//
/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
/// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
/// if either operand can be negative. Emulate it via spv.UMod.
+template <typename SignedAbsOp>
static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
Value signOperand, OpBuilder &builder) {
assert(lhs.getType() == rhs.getType());
Type type = lhs.getType();
// Calculate the remainder with spv.UMod.
- Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
- Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
+ Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
+ Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
// Fix the sign.
}
LogicalResult
-RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const {
- Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0],
- adaptor.getOperands()[1],
- adaptor.getOperands()[0], rewriter);
+RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
+ op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+ adaptor.getOperands()[0], rewriter);
+ rewriter.replaceOp(op, result);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RemSIOpOCLPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
+ op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+ adaptor.getOperands()[0], rewriter);
rewriter.replaceOp(op, result);
return success();
spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
- RemSIOpPattern,
+ RemSIOpGLSLPattern, RemSIOpOCLPattern,
BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
XOrIOpLogicalPattern, XOrIOpBooleanPattern,
// Check float unary operation conversions.
// CHECK-LABEL: @float32_unary_scalar
func @float32_unary_scalar(%arg0: f32) {
- // CHECK: spv.GLSL.FAbs %{{.*}}: f32
- %0 = math.abs %arg0 : f32
- // CHECK: spv.GLSL.Ceil %{{.*}}: f32
- %1 = math.ceil %arg0 : f32
// CHECK: spv.FNegate %{{.*}}: f32
- %5 = arith.negf %arg0 : f32
- // CHECK: spv.GLSL.Floor %{{.*}}: f32
- %10 = math.floor %arg0 : f32
+ %0 = arith.negf %arg0 : f32
return
}
}
} // end module
+
+// -----
+
+// Check OpenCL lowering of arith.remsi
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Int16, Kernel], []>, {}>
+} {
+
+// CHECK-LABEL: @scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func @scalar_srem(%lhs: i32, %rhs: i32) {
+ // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : i32
+ // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : i32
+ // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32
+ // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32
+ // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32
+ // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+ %0 = arith.remsi %lhs, %rhs: i32
+ return
+}
+
+// CHECK-LABEL: @vector_srem
+// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>)
+func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) {
+ // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : vector<3xi16>
+ // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : vector<3xi16>
+ // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : vector<3xi16>
+ // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : vector<3xi16>
+ // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : vector<3xi16>
+ // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : vector<3xi1>, vector<3xi16>
+ %0 = arith.remsi %arg0, %arg1: vector<3xi16>
+ return
+}
+
+} // end module