#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "arith-to-spirv-pattern"
ConversionPatternRewriter &rewriter) const override;
};
+/// Converts arith.addi_carry to spv.IAddCarry.
+class AddICarryOpPattern final : public OpConversionPattern<arith::AddICarryOp> {
+public:
+ using OpConversionPattern<arith::AddICarryOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts arith.select to spv.Select.
class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
public:
}
//===----------------------------------------------------------------------===//
+// AddICarryOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AddICarryOpPattern::matchAndRewrite(arith::AddICarryOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Type dstElemTy = adaptor.getLhs().getType();
+ auto resultTy = spirv::StructType::get({dstElemTy, dstElemTy});
+
+ Location loc = op->getLoc();
+ Value result = rewriter.create<spirv::IAddCarryOp>(
+ loc, resultTy, adaptor.getLhs(), adaptor.getRhs());
+
+ Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
+ loc, result, llvm::makeArrayRef(0));
+ Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
+ loc, result, llvm::makeArrayRef(1));
+
+ // Convert the carry value to boolean.
+ Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
+ Value carryResult =
+ rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
+
+ rewriter.replaceOp(op, {sumResult, carryResult});
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// SelectOpPattern
//===----------------------------------------------------------------------===//
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern,
- SelectOpPattern,
+ AddICarryOpPattern, SelectOpPattern,
spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
return
}
+// Check integer add-with-carry conversions.
+// CHECK-LABEL: @int32_scalar_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_addi_carry(%lhs: i32, %rhs: i32) -> (i32, i1) {
+ // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(i32, i32)>
+ // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(i32, i32)>
+ // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(i32, i32)>
+ // CHECK-DAG: %[[ONE:.+]] = spv.Constant 1 : i32
+ // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : i32
+ // CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1
+ %sum, %carry = arith.addi_carry %lhs, %rhs: i32, i1
+ return %sum, %carry : i32, i1
+}
+
+// CHECK-LABEL: @int32_vector_addi_carry
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_addi_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
+ // CHECK-NEXT: %[[IAC:.+]] = spv.IAddCarry %[[LHS]], %[[RHS]] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[SUM:.+]] = spv.CompositeExtract %[[IAC]][0 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[C0:.+]] = spv.CompositeExtract %[[IAC]][1 : i32] : !spv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[ONE:.+]] = spv.Constant dense<1> : vector<4xi32>
+ // CHECK-NEXT: %[[C1:.+]] = spv.IEqual %[[C0]], %[[ONE]] : vector<4xi32>
+ // CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1>
+ %sum, %carry = arith.addi_carry %lhs, %rhs: vector<4xi32>, vector<4xi1>
+ return %sum, %carry : vector<4xi32>, vector<4xi1>
+}
+
// Check float unary operation conversions.
// CHECK-LABEL: @float32_unary_scalar
func.func @float32_unary_scalar(%arg0: f32) {