[mlir][spirv] arith::RemSIOp OpenCL lowering
authorButygin <ivan.butygin@intel.com>
Thu, 28 Oct 2021 16:04:35 +0000 (19:04 +0300)
committerButygin <ivan.butygin@intel.com>
Thu, 25 Nov 2021 09:44:06 +0000 (12:44 +0300)
Differential Revision: https://reviews.llvm.org/D114524

mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

index 6fd69637df1d8584cc476012ae0a54e9a79dd034..99e15a2f25a11aa34f0fb2cffeb07603560a3d0c 100644 (file)
@@ -45,11 +45,20 @@ struct ConstantScalarOpPattern final
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts arith.remsi to SPIR-V ops.
+/// Converts arith.remsi to GLSL SPIR-V ops.
 ///
 /// This cannot be merged into the template unary/binary pattern due to Vulkan
 /// restrictions over spv.SRem and spv.SMod.
-struct RemSIOpPattern final : public OpConversionPattern<arith::RemSIOp> {
+struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
+  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts arith.remsi to OpenCL SPIR-V ops.
+struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
 
   LogicalResult
@@ -396,7 +405,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
-// RemSIOpPattern
+// RemSIOpGLSLPattern
 //===----------------------------------------------------------------------===//
 
 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
@@ -406,6 +415,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
 /// if either operand can be negative. Emulate it via spv.UMod.
+template <typename SignedAbsOp>
 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
                                     Value signOperand, OpBuilder &builder) {
   assert(lhs.getType() == rhs.getType());
@@ -414,8 +424,8 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
   Type type = lhs.getType();
 
   // Calculate the remainder with spv.UMod.
-  Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
-  Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
+  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
+  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
 
   // Fix the sign.
@@ -429,11 +439,26 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
 }
 
 LogicalResult
-RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const {
-  Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0],
-                                        adaptor.getOperands()[1],
-                                        adaptor.getOperands()[0], rewriter);
+RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
+      op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+      adaptor.getOperands()[0], rewriter);
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RemSIOpOCLPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter) const {
+  Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
+      op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
+      adaptor.getOperands()[0], rewriter);
   rewriter.replaceOp(op, result);
 
   return success();
@@ -762,7 +787,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
     spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
     spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
     spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
-    RemSIOpPattern,
+    RemSIOpGLSLPattern, RemSIOpOCLPattern,
     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
index 5ab13e780d0d4fcef1cbae6c1ab35ab90898f625..291c7fdd77b1c13ee8cf820986e172049e096b88 100644 (file)
@@ -43,14 +43,8 @@ func @scalar_srem(%lhs: i32, %rhs: i32) {
 // Check float unary operation conversions.
 // CHECK-LABEL: @float32_unary_scalar
 func @float32_unary_scalar(%arg0: f32) {
-  // CHECK: spv.GLSL.FAbs %{{.*}}: f32
-  %0 = math.abs %arg0 : f32
-  // CHECK: spv.GLSL.Ceil %{{.*}}: f32
-  %1 = math.ceil %arg0 : f32
   // CHECK: spv.FNegate %{{.*}}: f32
-  %5 = arith.negf %arg0 : f32
-  // CHECK: spv.GLSL.Floor %{{.*}}: f32
-  %10 = math.floor %arg0 : f32
+  %0 = arith.negf %arg0 : f32
   return
 }
 
@@ -842,3 +836,39 @@ func @sitofp(%arg0 : i64) -> f64 {
 }
 
 } // end module
+
+// -----
+
+// Check OpenCL lowering of arith.remsi
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Int16, Kernel], []>, {}>
+} {
+
+// CHECK-LABEL: @scalar_srem
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func @scalar_srem(%lhs: i32, %rhs: i32) {
+  // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : i32
+  // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : i32
+  // CHECK:  %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32
+  // CHECK:  %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32
+  // CHECK:  %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32
+  // CHECK:      %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
+  %0 = arith.remsi %lhs, %rhs: i32
+  return
+}
+
+// CHECK-LABEL: @vector_srem
+// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>)
+func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) {
+  // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : vector<3xi16>
+  // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : vector<3xi16>
+  // CHECK:  %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : vector<3xi16>
+  // CHECK:  %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : vector<3xi16>
+  // CHECK:  %[[NEG:.+]] = spv.SNegate %[[ABS]] : vector<3xi16>
+  // CHECK:      %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : vector<3xi1>, vector<3xi16>
+  %0 = arith.remsi %arg0, %arg1: vector<3xi16>
+  return
+}
+
+} // end module
index ad32a88a876eae104e7eec4ed3a50afd0fe94121..8cae1ca7d94efed24e6d48df7f018fe7f1d9f850 100644 (file)
@@ -22,6 +22,12 @@ func @float32_unary_scalar(%arg0: f32) {
   %6 = math.tanh %arg0 : f32
   // CHECK: spv.GLSL.Sin %{{.*}}: f32
   %7 = math.sin %arg0 : f32
+  // CHECK: spv.GLSL.FAbs %{{.*}}: f32
+  %8 = math.abs %arg0 : f32
+  // CHECK: spv.GLSL.Ceil %{{.*}}: f32
+  %9 = math.ceil %arg0 : f32
+  // CHECK: spv.GLSL.Floor %{{.*}}: f32
+  %10 = math.floor %arg0 : f32
   return
 }
 
index 8a1a3acc5f0cd8930fb9e7742b128a8e2680815a..5bfd4e477c21c8953798010cc208a5487af81663 100644 (file)
@@ -22,6 +22,12 @@ func @float32_unary_scalar(%arg0: f32) {
   %6 = math.tanh %arg0 : f32
   // CHECK: spv.OCL.sin %{{.*}}: f32
   %7 = math.sin %arg0 : f32
+  // CHECK: spv.OCL.fabs %{{.*}}: f32
+  %8 = math.abs %arg0 : f32
+  // CHECK: spv.OCL.ceil %{{.*}}: f32
+  %9 = math.ceil %arg0 : f32
+  // CHECK: spv.OCL.floor %{{.*}}: f32
+  %10 = math.floor %arg0 : f32
   return
 }