[mlir][spirv] Add arith.addi_carry to spv.IAddCarry conversion
authorJakub Kuderski <kubak@google.com>
Thu, 18 Aug 2022 01:32:00 +0000 (21:32 -0400)
committerJakub Kuderski <kubak@google.com>
Thu, 18 Aug 2022 01:33:34 +0000 (21:33 -0400)
Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D131908

mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir

index 52ab62c..56a241c 100644 (file)
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "arith-to-spirv-pattern"
@@ -192,6 +195,15 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts arith.addi_carry to spv.IAddCarry.
+class AddICarryOpPattern final : public OpConversionPattern<arith::AddICarryOp> {
+public:
+  using OpConversionPattern<arith::AddICarryOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts arith.select to spv.Select.
 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
 public:
@@ -834,6 +846,34 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
+// AddICarryOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  Type dstElemTy = adaptor.getLhs().getType();
+  auto resultTy = spirv::StructType::get({dstElemTy, dstElemTy});
+
+  Location loc = op->getLoc();
+  Value result = rewriter.create<spirv::IAddCarryOp>(
+      loc, resultTy, adaptor.getLhs(), adaptor.getRhs());
+
+  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
+      loc, result, llvm::makeArrayRef(0));
+  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
+      loc, result, llvm::makeArrayRef(1));
+
+  // Convert the carry value to boolean.
+  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
+  Value carryResult =
+      rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+
+  rewriter.replaceOp(op, {sumResult, carryResult});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // SelectOpPattern
 //===----------------------------------------------------------------------===//
 
@@ -887,7 +927,7 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
-    SelectOpPattern,
+    AddICarryOpPattern, SelectOpPattern,
 
     spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
index 6b8cba2..ca48648 100644 (file)
@@ -72,6 +72,33 @@ func.func @index_scalar_srem(%lhs: index, %rhs: index) {
   return
 }
 
+// Check integer add-with-carry conversions.
+// CHECK-LABEL: @int32_scalar_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_addi_carry(%lhs: i32, %rhs: i32) -> (i32, i1) {
+  // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(i32, i32)>
+  // CHECK-DAG:  %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(i32, i32)>
+  // CHECK-DAG:  %[[C0:.+]]  = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(i32, i32)>
+  // CHECK-DAG:  %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK-NEXT: %[[C1:.+]]  = spv.IEqual %[[C0]], %[[ONE]] : i32
+  // CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1
+  %sum, %carry = arith.addi_carry %lhs, %rhs: i32, i1
+  return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @int32_vector_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_addi_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+  // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG:  %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG:  %[[C0:.+]]  = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG:  %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32>
+  // CHECK-NEXT: %[[C1:.+]]  = spv.IEqual %[[C0]], %[[ONE]] : vector<4xi32>
+  // CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1>
+  %sum, %carry = arith.addi_carry %lhs, %rhs: vector<4xi32>, vector<4xi1>
+  return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
 // Check float unary operation conversions.
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {