From d9edc1a585d7b4a203ee29136260282bc9c65c95 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 19 Jan 2022 10:32:16 -0800 Subject: [PATCH] [mlir][spirv] Add math.fma lowering to spirv Differential Revision: https://reviews.llvm.org/D117704 --- .../ArithmeticToSPIRV/ArithmeticToSPIRV.cpp | 30 ++++++------ mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp | 53 +++++++++++----------- mlir/lib/Conversion/SPIRVCommon/Pattern.h | 7 +-- .../Conversion/StandardToSPIRV/StandardToSPIRV.cpp | 12 ++--- .../Conversion/MathToSPIRV/math-to-glsl-spirv.mlir | 15 ++++++ 5 files changed, 67 insertions(+), 50 deletions(-) diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp index 4fd2b6d..b3972a6 100644 --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -790,25 +790,25 @@ void mlir::arith::populateArithmeticToSPIRVPatterns( patterns.add< ConstantCompositeOpPattern, ConstantScalarOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, RemSIOpGLSLPattern, RemSIOpOCLPattern, BitwiseOpPattern, BitwiseOpPattern, XOrIOpLogicalPattern, XOrIOpBooleanPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, TypeCastingOpPattern, ExtUII1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index 7e95d33..ec8402a 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -64,35 +64,36 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { // GLSL patterns - patterns.add< - Log1pOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern>( - typeConverter, patterns.getContext()); + patterns + .add, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern>( + typeConverter, patterns.getContext()); // OpenCL patterns patterns.add, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern>( + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern>( typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h index 3933899..c70009f 100644 --- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h +++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h @@ -15,16 +15,17 @@ namespace mlir { namespace spirv { -/// Converts unary and binary standard operations to SPIR-V operations. +/// Converts elementwise unary, binary and ternary standard operations to SPIR-V +/// operations. template -class UnaryAndBinaryOpPattern final : public OpConversionPattern { +class ElementwiseOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(adaptor.getOperands().size() <= 2); + assert(adaptor.getOperands().size() <= 3); auto dstType = this->getTypeConverter()->convertType(op.getType()); if (!dstType) return failure(); diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index fea7c7c..8d39aa9 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -230,12 +230,12 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter, patterns.add< // Unary and binary patterns - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter, context); diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir index 8cae1ca..f0e0b7e 100644 --- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -68,4 +68,19 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) { return } + // CHECK-LABEL: @float32_ternary_scalar +func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) { + // CHECK: spv.GLSL.Fma %{{.*}}: f32 + %0 = math.fma %a, %b, %c : f32 + return +} + +// CHECK-LABEL: @float32_ternary_vector +func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>, + %c: vector<4xf32>) { + // CHECK: spv.GLSL.Fma %{{.*}}: vector<4xf32> + %0 = math.fma %a, %b, %c : vector<4xf32> + return +} + } // end module -- 2.7.4