From 95c4e518393cbb0d6ed2c615c08347960995c48a Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 17 Aug 2022 21:32:00 -0400 Subject: [PATCH] [mlir][spirv] Add arith.addi_carry to spv.IAddCarry conversion Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D131908 --- .../ArithmeticToSPIRV/ArithmeticToSPIRV.cpp | 42 +++++++++++++++++++++- .../ArithmeticToSPIRV/arithmetic-to-spirv.mlir | 27 ++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp index 52ab62c..56a241c 100644 --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -13,8 +13,11 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "arith-to-spirv-pattern" @@ -192,6 +195,15 @@ public: ConversionPatternRewriter &rewriter) const override; }; +/// Converts arith.addi_carry to spv.IAddCarry. +class AddICarryOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts arith.select to spv.Select. class SelectOpPattern final : public OpConversionPattern { public: @@ -834,6 +846,34 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite( } //===----------------------------------------------------------------------===// +// AddICarryOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult +AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Type dstElemTy = adaptor.getLhs().getType(); + auto resultTy = spirv::StructType::get({dstElemTy, dstElemTy}); + + Location loc = op->getLoc(); + Value result = rewriter.create( + loc, resultTy, adaptor.getLhs(), adaptor.getRhs()); + + Value sumResult = rewriter.create( + loc, result, llvm::makeArrayRef(0)); + Value carryValue = rewriter.create( + loc, result, llvm::makeArrayRef(1)); + + // Convert the carry value to boolean. + Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); + Value carryResult = + rewriter.create(loc, carryValue, one); + + rewriter.replaceOp(op, {sumResult, carryResult}); + return success(); +} + +//===----------------------------------------------------------------------===// // SelectOpPattern //===----------------------------------------------------------------------===// @@ -887,7 +927,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns( TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, CmpFOpNanNonePattern, CmpFOpPattern, - SelectOpPattern, + AddICarryOpPattern, SelectOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir index 6b8cba2..ca48648 100644 --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -72,6 +72,33 @@ func.func @index_scalar_srem(%lhs: index, %rhs: index) { return } +// Check integer add-with-carry conversions. +// CHECK-LABEL: @int32_scalar_addi_carry +// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) +func.func @int32_scalar_addi_carry(%lhs: i32, %rhs: i32) -> (i32, i1) { + // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(i32, i32)> + // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(i32, i32)> + // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(i32, i32)> + // CHECK-DAG: %[[ONE:.+]] = spv.Constant 1 : i32 + // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : i32 + // CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1 + %sum, %carry = arith.addi_carry %lhs, %rhs: i32, i1 + return %sum, %carry : i32, i1 +} + +// CHECK-LABEL: @int32_vector_addi_carry +// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>) +func.func @int32_vector_addi_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) { + // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(vector<4xi32>, vector<4xi32>)> + // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)> + // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)> + // CHECK-DAG: %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32> + // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : vector<4xi32> + // CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1> + %sum, %carry = arith.addi_carry %lhs, %rhs: vector<4xi32>, vector<4xi1> + return %sum, %carry : vector<4xi32>, vector<4xi1> +} + // Check float unary operation conversions. // CHECK-LABEL: @float32_unary_scalar func.func @float32_unary_scalar(%arg0: f32) { -- 2.7.4