From: Stanley Winata Date: Mon, 19 Sep 2022 17:28:55 +0000 (-0400) Subject: [mlir][spirv] Lower arith max/min ops to OpenCL ones X-Git-Tag: upstream/17.0.6~33130 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7d23d1e640dcde0e90a42353102198b95e20e5f4;p=platform%2Fupstream%2Fllvm.git [mlir][spirv] Lower arith max/min ops to OpenCL ones Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D132881 --- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td index 524c375..d95ed47 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td @@ -332,6 +332,67 @@ def SPV_CLFmaOp : SPV_CLTernaryArithmeticOp<"fma", 26, SPV_Float> { // ----- +def SPV_CLFMaxOp : SPV_CLBinaryArithmeticOp<"fmax", 27, SPV_Float> { + let summary = "Return maximum of two floating-point operands"; + + let description = [{ + Returns y if x < y, otherwise it returns x. If one argument is a NaN, + Fmax returns the other argument. If both arguments are NaNs, Fmax returns a NaN. + + Result Type, x and y must be floating-point or vector(2,3,4,8,16) + of floating-point values. + + All of the operands, including the Result Type operand, + must be of the same type. + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fmax-op ::= ssa-id `=` `spv.CL.fmax` ssa-use `:` + float-scalar-vector-type + ``` + #### Example: + + ```mlir + %2 = spv.CL.fmax %0, %1 : f32 + %3 = spv.CL.fmax %0, %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_CLFMinOp : SPV_CLBinaryArithmeticOp<"fmin", 28, SPV_Float> { + let summary = "Return minimum of two floating-point operands"; + + let description = [{ + Returns y if y < x, otherwise it returns x. If one argument is a NaN, Fmin returns the other argument. + If both arguments are NaNs, Fmin returns a NaN. + + Result Type,x and y must be floating-point or vector(2,3,4,8,16) of floating-point values. + + All of the operands, including the Result Type operand, must be of the same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fmin-op ::= ssa-id `=` `spv.CL.fmin` ssa-use `:` + float-scalar-vector-type + ``` + #### Example: + + ```mlir + %2 = spv.CL.fmin %0, %1 : f32 + %3 = spv.CL.fmin %0, %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_CLLogOp : SPV_CLUnaryArithmeticOp<"log", 37, SPV_Float> { let summary = "Compute the natural logarithm of x."; @@ -573,4 +634,110 @@ def SPV_CLSAbsOp : SPV_CLUnaryArithmeticOp<"s_abs", 141, SPV_Integer> { }]; } +// ----- + +def SPV_CLSMaxOp : SPV_CLBinaryArithmeticOp<"s_max", 156, SPV_Integer> { + let summary = "Return maximum of two signed integer operands"; + + let description = [{ + Returns y if x < y, otherwise it returns x, where x and y are treated as signed integers. + + Result Type,x and y must be integer or vector(2,3,4,8,16) of integer values. + + All of the operands, including the Result Type operand, must be of the same type. + + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + smax-op ::= ssa-id `=` `spv.CL.s_max` ssa-use `:` + integer-scalar-vector-type + ``` + #### Example: + ```mlir + %2 = spv.CL.s_max %0, %1 : i32 + %3 = spv.CL.s_max %0, %1 : vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_CLUMaxOp : SPV_CLBinaryArithmeticOp<"u_max", 157, SPV_Integer> { + let summary = "Return maximum of two unsigned integer operands"; + + let description = [{ + Returns y if x < y, otherwise it returns x, where x and y are treated as unsigned integers. + + Result Type,x and y must be integer or vector(2,3,4,8,16) of integer values. + + All of the operands, including the Result Type operand, must be of the same type. + + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + umax-op ::= ssa-id `=` `spv.CL.u_max` ssa-use `:` + integer-scalar-vector-type + ``` + #### Example: + ```mlir + %2 = spv.CL.u_max %0, %1 : i32 + %3 = spv.CL.u_max %0, %1 : vector<3xi16> + ``` + }]; +} + +def SPV_CLSMinOp : SPV_CLBinaryArithmeticOp<"s_min", 158, SPV_Integer> { + let summary = "Return minimum of two signed integer operands"; + + let description = [{ + Returns y if x < y, otherwise it returns x, where x and y are treated as signed integers. + + Result Type,x and y must be integer or vector(2,3,4,8,16) of integer values. + + All of the operands, including the Result Type operand, must be of the same type. + + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + smin-op ::= ssa-id `=` `spv.CL.s_min` ssa-use `:` + integer-scalar-vector-type + ``` + #### Example: + ```mlir + %2 = spv.CL.s_min %0, %1 : i32 + %3 = spv.CL.s_min %0, %1 : vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_CLUMinOp : SPV_CLBinaryArithmeticOp<"u_min", 159, SPV_Integer> { + let summary = "Return minimum of two unsigned integer operands"; + + let description = [{ + Returns y if x < y, otherwise it returns x, where x and y are treated as unsigned integers. + + Result Type,x and y must be integer or vector(2,3,4,8,16) of integer values. + + All of the operands, including the Result Type operand, must be of the same type. + + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + umin-op ::= ssa-id `=` `spv.CL.u_min` ssa-use `:` + integer-scalar-vector-type + ``` + #### Example: + ```mlir + %2 = spv.CL.u_min %0, %1 : i32 + %3 = spv.CL.u_min %0, %1 : vector<3xi16> + ``` + }]; +} + #endif // MLIR_DIALECT_SPIRV_IR_CL_OPS diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp index 18c3abe..c6c3f2b 100644 --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -219,7 +219,7 @@ public: ConversionPatternRewriter &rewriter) const override; }; -/// Converts arith.maxf to spv.GL.FMax. +/// Converts arith.maxf to spv.GL.FMax or spv.CL.fmax. template class MinMaxFOpPattern final : public OpConversionPattern { public: @@ -926,9 +926,11 @@ LogicalResult MinMaxFOpPattern::matchAndRewrite( // arith.maxf/minf: // "if one of the arguments is NaN, then the result is also NaN." - // spv.GL.FMax/FMin: + // spv.GL.FMax/FMin // "which operand is the result is undefined if one of the operands // is a NaN." + // spv.CL.fmax/fmin: + // "If one argument is a NaN, Fmin returns the other argument." Location loc = op.getLoc(); Value spirvOp = rewriter.create(loc, dstType, adaptor.getOperands()); @@ -998,7 +1000,14 @@ void mlir::arith::populateArithmeticToSPIRVPatterns( spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern + spirv::ElementwiseOpPattern, + + MinMaxFOpPattern, + MinMaxFOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern >(typeConverter, patterns.getContext()); // clang-format on diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir index 9430183..8b4de3d 100644 --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -970,12 +970,80 @@ func.func @sitofp(%arg0 : i64) -> f64 { // ----- -// Check OpenCL lowering of arith.remsi +// Check various lowerings for OpenCL. module attributes { spv.target_env = #spv.target_env< #spv.vce, #spv.resource_limits<>> } { +// Check integer operation conversions. +// CHECK-LABEL: @int32_scalar +func.func @int32_scalar(%lhs: i32, %rhs: i32) { + // CHECK: spv.IAdd %{{.*}}, %{{.*}}: i32 + %0 = arith.addi %lhs, %rhs: i32 + // CHECK: spv.ISub %{{.*}}, %{{.*}}: i32 + %1 = arith.subi %lhs, %rhs: i32 + // CHECK: spv.IMul %{{.*}}, %{{.*}}: i32 + %2 = arith.muli %lhs, %rhs: i32 + // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32 + %3 = arith.divsi %lhs, %rhs: i32 + // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32 + %4 = arith.divui %lhs, %rhs: i32 + // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32 + %5 = arith.remui %lhs, %rhs: i32 + // CHECK: spv.CL.s_max %{{.*}}, %{{.*}}: i32 + %6 = arith.maxsi %lhs, %rhs : i32 + // CHECK: spv.CL.u_max %{{.*}}, %{{.*}}: i32 + %7 = arith.maxui %lhs, %rhs : i32 + // CHECK: spv.CL.s_min %{{.*}}, %{{.*}}: i32 + %8 = arith.minsi %lhs, %rhs : i32 + // CHECK: spv.CL.u_min %{{.*}}, %{{.*}}: i32 + %9 = arith.minui %lhs, %rhs : i32 + return +} + +// Check float binary operation conversions. +// CHECK-LABEL: @float32_binary_scalar +func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) { + // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32 + %0 = arith.addf %lhs, %rhs: f32 + // CHECK: spv.FSub %{{.*}}, %{{.*}}: f32 + %1 = arith.subf %lhs, %rhs: f32 + // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32 + %2 = arith.mulf %lhs, %rhs: f32 + // CHECK: spv.FDiv %{{.*}}, %{{.*}}: f32 + %3 = arith.divf %lhs, %rhs: f32 + // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32 + %4 = arith.remf %lhs, %rhs: f32 + return +} + +// CHECK-LABEL: @float32_minf_scalar +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[MIN:.+]] = spv.CL.fmin %arg0, %arg1 : f32 + // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : f32 + // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : f32 + // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MIN]] + // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]] + %0 = arith.minf %arg0, %arg1 : f32 + // CHECK: return %[[SELECT2]] + return %0: f32 +} + +// CHECK-LABEL: @float32_maxf_scalar +// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32> +func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> { + // CHECK: %[[MAX:.+]] = spv.CL.fmax %arg0, %arg1 : vector<2xf32> + // CHECK: %[[LHS_NAN:.+]] = spv.IsNan %[[LHS]] : vector<2xf32> + // CHECK: %[[RHS_NAN:.+]] = spv.IsNan %[[RHS]] : vector<2xf32> + // CHECK: %[[SELECT1:.+]] = spv.Select %[[LHS_NAN]], %[[LHS]], %[[MAX]] + // CHECK: %[[SELECT2:.+]] = spv.Select %[[RHS_NAN]], %[[RHS]], %[[SELECT1]] + %0 = arith.maxf %arg0, %arg1 : vector<2xf32> + // CHECK: return %[[SELECT2]] + return %0: vector<2xf32> +} + // CHECK-LABEL: @scalar_srem // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) func.func @scalar_srem(%lhs: i32, %rhs: i32) { diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir index d4eba2c..c4d5478 100644 --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -185,3 +185,45 @@ func.func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () %2 = spv.CL.fma %a, %b, %c : vector<3xf32> return } + +// ----- + +//===----------------------------------------------------------------------===// +// spv.CL.{F|S|U}{Max|Min} +//===----------------------------------------------------------------------===// + +func.func @fmaxmin(%arg0 : f32, %arg1 : f32) { + // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : f32 + %1 = spv.CL.fmax %arg0, %arg1 : f32 + // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : f32 + %2 = spv.CL.fmin %arg0, %arg1 : f32 + return +} + +func.func @fmaxminvec(%arg0 : vector<3xf16>, %arg1 : vector<3xf16>) { + // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : vector<3xf16> + %1 = spv.CL.fmax %arg0, %arg1 : vector<3xf16> + // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : vector<3xf16> + %2 = spv.CL.fmin %arg0, %arg1 : vector<3xf16> + return +} + +func.func @fmaxminf64(%arg0 : f64, %arg1 : f64) { + // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : f64 + %1 = spv.CL.fmax %arg0, %arg1 : f64 + // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : f64 + %2 = spv.CL.fmin %arg0, %arg1 : f64 + return +} + +func.func @iminmax(%arg0: i32, %arg1: i32) { + // CHECK: spv.CL.s_max {{%.*}}, {{%.*}} : i32 + %1 = spv.CL.s_max %arg0, %arg1 : i32 + // CHECK: spv.CL.u_max {{%.*}}, {{%.*}} : i32 + %2 = spv.CL.u_max %arg0, %arg1 : i32 + // CHECK: spv.CL.s_min {{%.*}}, {{%.*}} : i32 + %3 = spv.CL.s_min %arg0, %arg1 : i32 + // CHECK: spv.CL.u_min {{%.*}}, {{%.*}} : i32 + %4 = spv.CL.u_min %arg0, %arg1 : i32 + return +} diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir index 784f563..d5934f0 100644 --- a/mlir/test/Target/SPIRV/ocl-ops.mlir +++ b/mlir/test/Target/SPIRV/ocl-ops.mlir @@ -44,4 +44,21 @@ spv.module Physical64 OpenCL requires #spv.vce { %13 = spv.CL.fma %arg0, %arg1, %arg2 : f32 spv.Return } + + spv.func @maxmin(%arg0 : f32, %arg1 : f32, %arg2 : i32, %arg3 : i32) "None" { + // CHECK: {{%.*}} = spv.CL.fmax {{%.*}}, {{%.*}} : f32 + %1 = spv.CL.fmax %arg0, %arg1 : f32 + // CHECK: {{%.*}} = spv.CL.s_max {{%.*}}, {{%.*}} : i32 + %2 = spv.CL.s_max %arg2, %arg3 : i32 + // CHECK: {{%.*}} = spv.CL.u_max {{%.*}}, {{%.*}} : i32 + %3 = spv.CL.u_max %arg2, %arg3 : i32 + + // CHECK: {{%.*}} = spv.CL.fmin {{%.*}}, {{%.*}} : f32 + %4 = spv.CL.fmin %arg0, %arg1 : f32 + // CHECK: {{%.*}} = spv.CL.s_min {{%.*}}, {{%.*}} : i32 + %5 = spv.CL.s_min %arg2, %arg3 : i32 + // CHECK: {{%.*}} = spv.CL.u_min {{%.*}}, {{%.*}} : i32 + %6 = spv.CL.u_min %arg2, %arg3 : i32 + spv.Return + } }