From 28246b7e759708e8e667cadef11b6a516c258dc6 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 7 Dec 2022 17:15:55 -0500 Subject: [PATCH] [mlir][arith] Rename addui_carry to addui_extended The goal is to make the naming of the future `_extended` ops more consistent. With unsigned addition, the carry value/flag and overflow bit are the same, but this is not true when it comes to signed addition. Also rename the second result from `carry` to `overflow`. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D139569 --- mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 25 +++++----- mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 24 ++++----- mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 18 +++---- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 57 +++++++++++----------- .../Dialect/Arith/Transforms/EmulateWideInt.cpp | 9 ++-- .../test/Conversion/ArithToLLVM/arith-to-llvm.mlir | 12 ++--- .../Conversion/ArithToSPIRV/arith-to-spirv.mlir | 16 +++--- mlir/test/Dialect/Arith/canonicalize.mlir | 16 +++--- mlir/test/Dialect/Arith/emulate-wide-int.mlir | 4 +- mlir/test/Dialect/Arith/invalid.mlir | 16 +++--- mlir/test/Dialect/Arith/ops.mlir | 24 ++++----- .../Dialect/Arith/test-emulate-wide-int-pass.mlir | 2 +- 12 files changed, 114 insertions(+), 109 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index cc1801b..6c7244b 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -222,33 +222,36 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> { } -def Arith_AddUICarryOp : Arith_Op<"addui_carry", [Pure, Commutative, +def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative, AllTypesMatch<["lhs", "rhs", "sum"]>]> { - let summary = "unsigned integer addition operation returning sum and carry"; + let summary = [{ + extended unsigned integer addition operation returning sum and overflow bit + }]; + let description = [{ - The `addui_carry` operation takes two operands and returns two results: the - sum (same type as both operands), and the carry (boolean-like). The carry - value `1` indicates unsigned addition overflow, while indicates `0` no - overflow. + Performs (N+1)-bit addition on zero-extended operands. Returns two results: + the N-bit sum (same type as both operands), and the overflow bit + (boolean-like), where`1` indicates unsigned addition overflow, while `0` + indicates no overflow. Example: ```mlir // Scalar addition. - %sum, %carry = arith.addui_carry %b, %c : i64, i1 + %sum, %overflow = arith.addui_extended %b, %c : i64, i1 // Vector element-wise addition. - %b:2 = arith.addui_carry %g, %h : vector<4xi32>, vector<4xi1> + %d:2 = arith.addui_extended %e, %f : vector<4xi32>, vector<4xi1> // Tensor element-wise addition. - %c:2 = arith.addui_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1> + %x:2 = arith.addui_extended %y, %z : tensor<4x?xi8>, tensor<4x?xi1> ``` }]; let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs); - let results = (outs SignlessIntegerLike:$sum, BoolLike:$carry); + let results = (outs SignlessIntegerLike:$sum, BoolLike:$overflow); let assemblyFormat = [{ - $lhs `,` $rhs attr-dict `:` type($sum) `,` type($carry) + $lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow) }]; let builders = [ diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 3ad0155..0289bea 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -133,12 +133,12 @@ using IndexCastOpSILowering = using IndexCastOpUILowering = IndexCastOpLowering; -struct AddUICarryOpLowering - : public ConvertOpToLLVMPattern { +struct AddUIExtendedOpLowering + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor, + matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -223,15 +223,15 @@ LogicalResult IndexCastOpLowering::matchAndRewrite( } //===----------------------------------------------------------------------===// -// AddUICarryOpLowering +// AddUIExtendedOpLowering //===----------------------------------------------------------------------===// -LogicalResult AddUICarryOpLowering::matchAndRewrite( - arith::AddUICarryOp op, OpAdaptor adaptor, +LogicalResult AddUIExtendedOpLowering::matchAndRewrite( + arith::AddUIExtendedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type operandType = adaptor.getLhs().getType(); Type sumResultType = op.getSum().getType(); - Type carryResultType = op.getCarry().getType(); + Type overflowResultType = op.getOverflow().getType(); if (!LLVM::isCompatibleType(operandType)) return failure(); @@ -241,16 +241,16 @@ LogicalResult AddUICarryOpLowering::matchAndRewrite( // Handle the scalar and 1D vector cases. if (!operandType.isa()) { - Type newCarryType = typeConverter->convertType(carryResultType); + Type newOverflowType = typeConverter->convertType(overflowResultType); Type structType = - LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newCarryType}); + LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); Value addOverflow = rewriter.create( loc, structType, adaptor.getLhs(), adaptor.getRhs()); Value sumExtracted = rewriter.create(loc, addOverflow, 0); - Value carryExtracted = + Value overflowExtracted = rewriter.create(loc, addOverflow, 1); - rewriter.replaceOp(op, {sumExtracted, carryExtracted}); + rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); return success(); } @@ -374,7 +374,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns( AddFOpLowering, AddIOpLowering, AndIOpLowering, - AddUICarryOpLowering, + AddUIExtendedOpLowering, BitcastOpLowering, ConstantOpLowering, CmpFOpLowering, diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index d550e0e..a127dd8 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -213,13 +213,13 @@ public: ConversionPatternRewriter &rewriter) const override; }; -/// Converts arith.addui_carry to spirv.IAddCarry. -class AddICarryOpPattern final - : public OpConversionPattern { +/// Converts arith.addui_extended to spirv.IAddCarry. +class AddUIExtendedOpPattern final + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor, + matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -920,12 +920,12 @@ LogicalResult CmpFOpNanNonePattern::matchAndRewrite( } //===----------------------------------------------------------------------===// -// AddICarryOpPattern +// AddUIExtendedOpPattern //===----------------------------------------------------------------------===// -LogicalResult -AddICarryOpPattern::matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { +LogicalResult AddUIExtendedOpPattern::matchAndRewrite( + arith::AddUIExtendedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Type dstElemTy = adaptor.getLhs().getType(); Location loc = op->getLoc(); Value result = rewriter.create(loc, adaptor.getLhs(), @@ -1040,7 +1040,7 @@ void mlir::arith::populateArithToSPIRVPatterns( TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, CmpFOpNanNonePattern, CmpFOpPattern, - AddICarryOpPattern, SelectOpPattern, + AddUIExtendedOpPattern, SelectOpPattern, MinMaxFOpPattern, MinMaxFOpPattern, diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 00e2396..0a2a8a9 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -219,75 +219,76 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, } //===----------------------------------------------------------------------===// -// AddUICarryOp +// AddUIExtendedOp //===----------------------------------------------------------------------===// -Optional> arith::AddUICarryOp::getShapeForUnroll() { +Optional> arith::AddUIExtendedOp::getShapeForUnroll() { if (auto vt = getType(0).dyn_cast()) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } -// Returns the carry bit, assuming that `sum` is the result of addition of -// `operand` and another number. -static APInt calculateCarry(const APInt &sum, const APInt &operand) { +// Returns the overflow bit, assuming that `sum` is the result of unsigned +// addition of `operand` and another number. +static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) { return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1); } LogicalResult -arith::AddUICarryOp::fold(ArrayRef operands, - SmallVectorImpl &results) { - auto carryTy = getCarry().getType(); - // addui_carry(x, 0) -> x, false +arith::AddUIExtendedOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + auto overflowTy = getOverflow().getType(); + // addui_extended(x, 0) -> x, false if (matchPattern(getRhs(), m_Zero())) { - auto carryZero = APInt::getZero(1); + auto overflowZero = APInt::getZero(1); Builder builder(getContext()); - auto falseValue = builder.getZeroAttr(carryTy); + auto falseValue = builder.getZeroAttr(overflowTy); results.push_back(getLhs()); results.push_back(falseValue); return success(); } - // addui_carry(constant_a, constant_b) -> constant_sum, constant_carry + // addui_overflow(constant_a, constant_b) -> constant_sum, constant_carry // Let the `constFoldBinaryOp` utility attempt to fold the sum of both - // operands. If that succeeds, calculate the carry boolean based on the sum + // operands. If that succeeds, calculate the overflow bit based on the sum // and the first (constant) operand, `lhs`. Note that we cannot simply call - // `constFoldBinaryOp` again to calculate the carry (bit) because the + // `constFoldBinaryOp` again to calculate the overflow bit because the // constructed attribute is of the same element type as both operands. if (Attribute sumAttr = constFoldBinaryOp( operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) { - Attribute carryAttr; + Attribute overflowAttr; if (auto lhs = operands[0].dyn_cast()) { - // Both arguments are scalars, calculate the scalar carry value. + // Both arguments are scalars, calculate the scalar overflow value. auto sum = sumAttr.cast(); - carryAttr = IntegerAttr::get( - carryTy, calculateCarry(sum.getValue(), lhs.getValue())); + overflowAttr = IntegerAttr::get( + overflowTy, + calculateUnsignedOverflow(sum.getValue(), lhs.getValue())); } else if (auto lhs = operands[0].dyn_cast()) { - // Both arguments are splats, calculate the splat carry value. + // Both arguments are splats, calculate the splat overflow value. auto sum = sumAttr.cast(); - APInt carry = calculateCarry(sum.getSplatValue(), - lhs.getSplatValue()); - carryAttr = SplatElementsAttr::get(carryTy, carry); + APInt overflow = calculateUnsignedOverflow(sum.getSplatValue(), + lhs.getSplatValue()); + overflowAttr = SplatElementsAttr::get(overflowTy, overflow); } else if (auto lhs = operands[0].dyn_cast()) { - // Othwerwise calculate element-wise carry values. + // Othwerwise calculate element-wise overflow values. auto sum = sumAttr.cast(); const auto numElems = static_cast(sum.getNumElements()); - SmallVector carryValues; - carryValues.reserve(numElems); + SmallVector overflowValues; + overflowValues.reserve(numElems); auto sumIt = sum.value_begin(); auto lhsIt = lhs.value_begin(); for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt) - carryValues.push_back(calculateCarry(*sumIt, *lhsIt)); + overflowValues.push_back(calculateUnsignedOverflow(*sumIt, *lhsIt)); - carryAttr = DenseElementsAttr::get(carryTy, carryValues); + overflowAttr = DenseElementsAttr::get(overflowTy, overflowValues); } else { return failure(); } results.push_back(sumAttr); - results.push_back(carryAttr); + results.push_back(overflowAttr); return success(); } diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp index 28134cf..f10fefb 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -276,11 +276,12 @@ struct ConvertAddI final : OpConversionPattern { auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, adaptor.getRhs()); - auto lowSum = rewriter.create(loc, lhsElem0, rhsElem0); - Value carryVal = - rewriter.create(loc, newElemTy, lowSum.getCarry()); + auto lowSum = + rewriter.create(loc, lhsElem0, rhsElem0); + Value overflowVal = + rewriter.create(loc, newElemTy, lowSum.getOverflow()); - Value high0 = rewriter.create(loc, carryVal, lhsElem1); + Value high0 = rewriter.create(loc, overflowVal, lhsElem1); Value high = rewriter.create(loc, high0, rhsElem1); Value resultVec = diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index d8e49a5..cf207c2 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -355,24 +355,24 @@ func.func @bitcast_1d(%arg0: vector<2xf32>) { // ----- -// CHECK-LABEL: @addui_carry_scalar +// CHECK-LABEL: @addui_extended_scalar // CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i1) -func.func @addui_carry_scalar(%arg0: i32, %arg1: i32) -> (i32, i1) { +func.func @addui_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i1) { // CHECK-NEXT: [[RES:%.+]] = "llvm.intr.uadd.with.overflow"([[ARG0]], [[ARG1]]) : (i32, i32) -> !llvm.struct<(i32, i1)> // CHECK-NEXT: [[SUM:%.+]] = llvm.extractvalue [[RES]][0] : !llvm.struct<(i32, i1)> // CHECK-NEXT: [[CARRY:%.+]] = llvm.extractvalue [[RES]][1] : !llvm.struct<(i32, i1)> - %sum, %carry = arith.addui_carry %arg0, %arg1 : i32, i1 + %sum, %carry = arith.addui_extended %arg0, %arg1 : i32, i1 // CHECK-NEXT: return [[SUM]], [[CARRY]] : i32, i1 return %sum, %carry : i32, i1 } -// CHECK-LABEL: @addui_carry_vector1d +// CHECK-LABEL: @addui_extended_vector1d // CHECK-SAME: ([[ARG0:%.+]]: vector<3xi16>, [[ARG1:%.+]]: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>) -func.func @addui_carry_vector1d(%arg0: vector<3xi16>, %arg1: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>) { +func.func @addui_extended_vector1d(%arg0: vector<3xi16>, %arg1: vector<3xi16>) -> (vector<3xi16>, vector<3xi1>) { // CHECK-NEXT: [[RES:%.+]] = "llvm.intr.uadd.with.overflow"([[ARG0]], [[ARG1]]) : (vector<3xi16>, vector<3xi16>) -> !llvm.struct<(vector<3xi16>, vector<3xi1>)> // CHECK-NEXT: [[SUM:%.+]] = llvm.extractvalue [[RES]][0] : !llvm.struct<(vector<3xi16>, vector<3xi1>)> // CHECK-NEXT: [[CARRY:%.+]] = llvm.extractvalue [[RES]][1] : !llvm.struct<(vector<3xi16>, vector<3xi1>)> - %sum, %carry = arith.addui_carry %arg0, %arg1 : vector<3xi16>, vector<3xi1> + %sum, %carry = arith.addui_extended %arg0, %arg1 : vector<3xi16>, vector<3xi1> // CHECK-NEXT: return [[SUM]], [[CARRY]] : vector<3xi16>, vector<3xi1> return %sum, %carry : vector<3xi16>, vector<3xi1> } diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index beb52c5..938bafa 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -73,30 +73,30 @@ func.func @index_scalar_srem(%lhs: index, %rhs: index) { } // Check integer add-with-carry conversions. -// CHECK-LABEL: @int32_scalar_addui_carry +// CHECK-LABEL: @int32_scalar_addui_extended // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) -func.func @int32_scalar_addui_carry(%lhs: i32, %rhs: i32) -> (i32, i1) { +func.func @int32_scalar_addui_extended(%lhs: i32, %rhs: i32) -> (i32, i1) { // CHECK-NEXT: %[[IAC:.+]] = spirv.IAddCarry %[[LHS]], %[[RHS]] : !spirv.struct<(i32, i32)> // CHECK-DAG: %[[SUM:.+]] = spirv.CompositeExtract %[[IAC]][0 : i32] : !spirv.struct<(i32, i32)> // CHECK-DAG: %[[C0:.+]] = spirv.CompositeExtract %[[IAC]][1 : i32] : !spirv.struct<(i32, i32)> // CHECK-DAG: %[[ONE:.+]] = spirv.Constant 1 : i32 // CHECK-NEXT: %[[C1:.+]] = spirv.IEqual %[[C0]], %[[ONE]] : i32 // CHECK-NEXT: return %[[SUM]], %[[C1]] : i32, i1 - %sum, %carry = arith.addui_carry %lhs, %rhs: i32, i1 - return %sum, %carry : i32, i1 + %sum, %overflow = arith.addui_extended %lhs, %rhs: i32, i1 + return %sum, %overflow : i32, i1 } -// CHECK-LABEL: @int32_vector_addui_carry +// CHECK-LABEL: @int32_vector_addui_extended // CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>) -func.func @int32_vector_addui_carry(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) { +func.func @int32_vector_addui_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) { // CHECK-NEXT: %[[IAC:.+]] = spirv.IAddCarry %[[LHS]], %[[RHS]] : !spirv.struct<(vector<4xi32>, vector<4xi32>)> // CHECK-DAG: %[[SUM:.+]] = spirv.CompositeExtract %[[IAC]][0 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)> // CHECK-DAG: %[[C0:.+]] = spirv.CompositeExtract %[[IAC]][1 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)> // CHECK-DAG: %[[ONE:.+]] = spirv.Constant dense<1> : vector<4xi32> // CHECK-NEXT: %[[C1:.+]] = spirv.IEqual %[[C0]], %[[ONE]] : vector<4xi32> // CHECK-NEXT: return %[[SUM]], %[[C1]] : vector<4xi32>, vector<4xi1> - %sum, %carry = arith.addui_carry %lhs, %rhs: vector<4xi32>, vector<4xi1> - return %sum, %carry : vector<4xi32>, vector<4xi1> + %sum, %overflow = arith.addui_extended %lhs, %rhs: vector<4xi32>, vector<4xi1> + return %sum, %overflow : vector<4xi32>, vector<4xi1> } // Check float unary operation conversions. diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index d2439a2..8b41aad 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -640,7 +640,7 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index { // CHECK-NEXT: return %arg0, %[[false]] func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) { %zero = arith.constant 0 : i32 - %sum, %carry = arith.addui_carry %arg0, %zero: i32, i1 + %sum, %carry = arith.addui_extended %arg0, %zero: i32, i1 return %sum, %carry : i32, i1 } @@ -649,7 +649,7 @@ func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) { // CHECK-NEXT: return %arg0, %[[false]] func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) { %zero = arith.constant dense<0> : vector<4xi32> - %sum, %carry = arith.addui_carry %arg0, %zero: vector<4xi32>, vector<4xi1> + %sum, %carry = arith.addui_extended %arg0, %zero: vector<4xi32>, vector<4xi1> return %sum, %carry : vector<4xi32>, vector<4xi1> } @@ -658,7 +658,7 @@ func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector // CHECK-NEXT: return %arg0, %[[false]] func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) { %zero = arith.constant 0 : i32 - %sum, %carry = arith.addui_carry %zero, %arg0: i32, i1 + %sum, %carry = arith.addui_extended %zero, %arg0: i32, i1 return %sum, %carry : i32, i1 } @@ -669,7 +669,7 @@ func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) { func.func @addiCarryConstants() -> (i32, i1) { %c13 = arith.constant 13 : i32 %c37 = arith.constant 37 : i32 - %sum, %carry = arith.addui_carry %c13, %c37: i32, i1 + %sum, %carry = arith.addui_extended %c13, %c37: i32, i1 return %sum, %carry : i32, i1 } @@ -680,7 +680,7 @@ func.func @addiCarryConstants() -> (i32, i1) { func.func @addiCarryConstantsOverflow1() -> (i32, i1) { %max = arith.constant 4294967295 : i32 %c1 = arith.constant 1 : i32 - %sum, %carry = arith.addui_carry %max, %c1: i32, i1 + %sum, %carry = arith.addui_extended %max, %c1: i32, i1 return %sum, %carry : i32, i1 } @@ -690,7 +690,7 @@ func.func @addiCarryConstantsOverflow1() -> (i32, i1) { // CHECK-NEXT: return %[[c_2]], %[[true]] func.func @addiCarryConstantsOverflow2() -> (i32, i1) { %max = arith.constant 4294967295 : i32 - %sum, %carry = arith.addui_carry %max, %max: i32, i1 + %sum, %carry = arith.addui_extended %max, %max: i32, i1 return %sum, %carry : i32, i1 } @@ -701,7 +701,7 @@ func.func @addiCarryConstantsOverflow2() -> (i32, i1) { func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) { %v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32> %v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32> - %sum, %carry = arith.addui_carry %v1, %v2 : vector<4xi32>, vector<4xi1> + %sum, %carry = arith.addui_extended %v1, %v2 : vector<4xi32>, vector<4xi1> return %sum, %carry : vector<4xi32>, vector<4xi1> } @@ -712,7 +712,7 @@ func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) { func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) { %v1 = arith.constant dense<1> : vector<4xi32> %v2 = arith.constant dense<2> : vector<4xi32> - %sum, %carry = arith.addui_carry %v1, %v2 : vector<4xi32>, vector<4xi1> + %sum, %carry = arith.addui_extended %v1, %v2 : vector<4xi32>, vector<4xi1> return %sum, %carry : vector<4xi32>, vector<4xi1> } diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir index 0f85e7a..ab47a56 100644 --- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir +++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir @@ -100,7 +100,7 @@ func.func @constant_vector() -> vector<3xi64> { // CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32> // CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32> // CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32> -// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1 +// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_extended [[LOW0]], [[LOW1]] : i32, i1 // CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : i1 to i32 // CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : i32 // CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : i32 @@ -118,7 +118,7 @@ func.func @addi_scalar_a_b(%a : i64, %b : i64) -> i64 { // CHECK-NEXT: [[HIGH0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32> // CHECK-NEXT: [[LOW1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32> // CHECK-NEXT: [[HIGH1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32> -// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_carry [[LOW0]], [[LOW1]] : vector<4x1xi32>, vector<4x1xi1> +// CHECK-NEXT: [[SUM_L:%.+]], [[CB:%.+]] = arith.addui_extended [[LOW0]], [[LOW1]] : vector<4x1xi32>, vector<4x1xi1> // CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[CB]] : vector<4x1xi1> to vector<4x1xi32> // CHECK-NEXT: [[SUM_H0:%.+]] = arith.addi [[CARRY]], [[HIGH0]] : vector<4x1xi32> // CHECK-NEXT: [[SUM_H1:%.+]] = arith.addi [[SUM_H0]], [[HIGH1]] : vector<4x1xi32> diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir index 9330756..729c865 100644 --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -111,32 +111,32 @@ func.func @func_with_ops(f32) { // ----- func.func @func_with_ops(%a: f32) { - // expected-error@+1 {{'arith.addui_carry' op operand #0 must be signless-integer-like}} - %r:2 = arith.addui_carry %a, %a : f32, i32 + // expected-error@+1 {{'arith.addui_extended' op operand #0 must be signless-integer-like}} + %r:2 = arith.addui_extended %a, %a : f32, i32 return } // ----- func.func @func_with_ops(%a: i32) { - // expected-error@+1 {{'arith.addui_carry' op result #1 must be bool-like}} - %r:2 = arith.addui_carry %a, %a : i32, i32 + // expected-error@+1 {{'arith.addui_extended' op result #1 must be bool-like}} + %r:2 = arith.addui_extended %a, %a : i32, i32 return } // ----- func.func @func_with_ops(%a: vector<8xi32>) { - // expected-error@+1 {{'arith.addui_carry' op if an operand is non-scalar, then all results must be non-scalar}} - %r:2 = arith.addui_carry %a, %a : vector<8xi32>, i1 + // expected-error@+1 {{'arith.addui_extended' op if an operand is non-scalar, then all results must be non-scalar}} + %r:2 = arith.addui_extended %a, %a : vector<8xi32>, i1 return } // ----- func.func @func_with_ops(%a: vector<8xi32>) { - // expected-error@+1 {{'arith.addui_carry' op all non-scalar operands/results must have the same shape and base type}} - %r:2 = arith.addui_carry %a, %a : vector<8xi32>, tensor<8xi1> + // expected-error@+1 {{'arith.addui_extended' op all non-scalar operands/results must have the same shape and base type}} + %r:2 = arith.addui_extended %a, %a : vector<8xi32>, tensor<8xi1> return } diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index 9d5c686..99a777d 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -25,27 +25,27 @@ func.func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8] return %0 : vector<[8]xi64> } -// CHECK-LABEL: test_addui_carry -func.func @test_addui_carry(%arg0 : i64, %arg1 : i64) -> i64 { - %sum, %carry = arith.addui_carry %arg0, %arg1 : i64, i1 +// CHECK-LABEL: test_addui_extended +func.func @test_addui_extended(%arg0 : i64, %arg1 : i64) -> i64 { + %sum, %overflow = arith.addui_extended %arg0, %arg1 : i64, i1 return %sum : i64 } -// CHECK-LABEL: test_addui_carry_tensor -func.func @test_addui_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { - %sum, %carry = arith.addui_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1> +// CHECK-LABEL: test_addui_extended_tensor +func.func @test_addui_extended_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> { + %sum, %overflow = arith.addui_extended %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1> return %sum : tensor<8x8xi64> } -// CHECK-LABEL: test_addui_carry_vector -func.func @test_addui_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> { - %0:2 = arith.addui_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1> +// CHECK-LABEL: test_addui_extended_vector +func.func @test_addui_extended_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> { + %0:2 = arith.addui_extended %arg0, %arg1 : vector<8xi64>, vector<8xi1> return %0#0 : vector<8xi64> } -// CHECK-LABEL: test_addui_carry_scalable_vector -func.func @test_addui_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { - %0:2 = arith.addui_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1> +// CHECK-LABEL: test_addui_extended_scalable_vector +func.func @test_addui_extended_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0:2 = arith.addui_extended %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1> return %0#0 : vector<[8]xi64> } diff --git a/mlir/test/Dialect/Arith/test-emulate-wide-int-pass.mlir b/mlir/test/Dialect/Arith/test-emulate-wide-int-pass.mlir index bc6151e..9e14fff 100644 --- a/mlir/test/Dialect/Arith/test-emulate-wide-int-pass.mlir +++ b/mlir/test/Dialect/Arith/test-emulate-wide-int-pass.mlir @@ -21,7 +21,7 @@ func.func @entry() { // CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[BCAST0]][1] : vector<2xi32> // CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[BCAST0]][0] : vector<2xi32> // CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[BCAST0]][1] : vector<2xi32> -// CHECK-NEXT: {{%.+}}, {{%.+}} = arith.addui_carry [[LOW0]], [[LOW1]] : i32, i1 +// CHECK-NEXT: {{%.+}}, {{%.+}} = arith.addui_extended [[LOW0]], [[LOW1]] : i32, i1 // CHECK: [[RES:%.+]] = llvm.bitcast {{%.+}} : vector<2xi32> to i64 // CHECK-NEXt: return [[RES]] : i64 func.func @emulate_me_please(%x : i64) -> i64 { -- 2.7.4