return VectorType::get(newShape, type.getElementType());
}
+// Returns a constant of integer of vector type filled with (repeated) `value`.
+static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
+ Location loc, Type type,
+ const APInt &value) {
+ Attribute attr;
+ if (auto intTy = type.dyn_cast<IntegerType>()) {
+ attr = rewriter.getIntegerAttr(type, value);
+ } else {
+ auto vecTy = type.cast<VectorType>();
+ attr = SplatElementsAttr::get(vecTy, value);
+ }
+
+ return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
+// Returns a constant of integer of vector type filled with (repeated) `value`.
+static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
+ Location loc, Type type,
+ int64_t value) {
+ unsigned elementBitWidth = 0;
+ if (auto intTy = type.dyn_cast<IntegerType>())
+ elementBitWidth = intTy.getWidth();
+ else
+ elementBitWidth = type.cast<VectorType>().getElementTypeBitWidth();
+
+ return createScalarOrSplatConstant(rewriter, loc, type,
+ APInt(elementBitWidth, value));
+}
+
// Extracts the `input` vector slice with elements at the last dimension offset
// by `lastOffset`. Returns a value of vector type with the last dimension
// reduced to x1 or fully scalarized, e.g.:
assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
"Wrong number of result components");
- Value resultVec =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
+ Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
for (auto [i, component] : llvm::enumerate(resultComponents))
resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
-
- Value lhs = adaptor.getLhs();
- Value rhs = adaptor.getRhs();
auto newTy = getTypeConverter()
->convertType(op.getType())
.dyn_cast_or_null<VectorType>();
Type newElemTy = reduceInnermostDim(newTy);
- auto [lhsElem0, lhsElem1] = extractLastDimHalves(rewriter, loc, lhs);
- auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, rhs);
+ auto [lhsElem0, lhsElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+ auto [rhsElem0, rhsElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getRhs());
auto lowSum = rewriter.create<arith::AddUICarryOp>(loc, lhsElem0, rhsElem0);
Value carryVal =
};
//===----------------------------------------------------------------------===//
+// ConvertMulI
+//===----------------------------------------------------------------------===//
+
+struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto newTy = getTypeConverter()
+ ->convertType(op.getType())
+ .dyn_cast_or_null<VectorType>();
+ if (!newTy)
+ return rewriter.notifyMatchFailure(loc, "expected scalar or vector type");
+
+ Type newElemTy = reduceInnermostDim(newTy);
+ unsigned newBitWidth = newTy.getElementTypeBitWidth();
+ unsigned digitBitWidth = newBitWidth / 2;
+
+ auto [lhsElem0, lhsElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getLhs());
+ auto [rhsElem0, rhsElem1] =
+ extractLastDimHalves(rewriter, loc, adaptor.getRhs());
+
+ // Emulate multiplication by splitting each input element of type i2N into 4
+ // digits of type iN and bit width i(N/2). This is so that the intermediate
+ // multiplications and additions do not overflow. We extract these i(N/2)
+ // digits from iN vector elements by masking (low digit) and shifting right
+ // (high digit).
+ //
+ // The multiplication algorithm used is the standard (long) multiplication.
+ // Multiplying two i2N integers produces (at most) a i4N result, but because
+ // the calculation of top i2N is not necessary, we omit it.
+ // In total, this implementations performs 10 intermediate multiplications
+ // and 16 additions. The number of multiplications could be decreased by
+ // switching to a more efficient algorithm like Karatsuba. This would,
+ // however, require being able to perform (intermediate) wide additions and
+ // subtractions, so it is not clear that such implementation would be more
+ // efficient.
+
+ APInt lowMaskVal(newBitWidth, 1);
+ lowMaskVal = lowMaskVal.shl(digitBitWidth) - 1;
+ Value lowMask =
+ createScalarOrSplatConstant(rewriter, loc, newElemTy, lowMaskVal);
+ auto getLowDigit = [lowMask, newElemTy, loc, &rewriter](Value v) {
+ return rewriter.create<arith::AndIOp>(loc, newElemTy, v, lowMask);
+ };
+
+ Value shiftVal =
+ createScalarOrSplatConstant(rewriter, loc, newElemTy, digitBitWidth);
+ auto getHighDigit = [shiftVal, loc, &rewriter](Value v) {
+ return rewriter.create<arith::ShRUIOp>(loc, v, shiftVal);
+ };
+
+ Value zeroDigit = createScalarOrSplatConstant(rewriter, loc, newElemTy, 0);
+ std::array<Value, 4> resultDigits = {zeroDigit, zeroDigit, zeroDigit,
+ zeroDigit};
+ std::array<Value, 4> lhsDigits = {
+ getLowDigit(lhsElem0), getHighDigit(lhsElem0), getLowDigit(lhsElem1),
+ getHighDigit(lhsElem1)};
+ std::array<Value, 4> rhsDigits = {
+ getLowDigit(rhsElem0), getHighDigit(rhsElem0), getLowDigit(rhsElem1),
+ getHighDigit(rhsElem1)};
+
+ for (unsigned i = 0, e = lhsDigits.size(); i != e; ++i) {
+ for (unsigned j = 0; i + j != e; ++j) {
+ Value mul =
+ rewriter.create<arith::MulIOp>(loc, lhsDigits[i], rhsDigits[j]);
+ Value current =
+ rewriter.createOrFold<arith::AddIOp>(loc, resultDigits[i + j], mul);
+ resultDigits[i + j] = getLowDigit(current);
+ if (i + j + 1 != e) {
+ Value carry = rewriter.createOrFold<arith::AddIOp>(
+ loc, resultDigits[i + j + 1], getHighDigit(current));
+ resultDigits[i + j + 1] = carry;
+ }
+ }
+ }
+
+ auto combineDigits = [shiftVal, loc, &rewriter](Value low, Value high) {
+ Value highBits = rewriter.create<arith::ShLIOp>(loc, high, shiftVal);
+ return rewriter.create<arith::OrIOp>(loc, low, highBits);
+ };
+ Value resultElem0 = combineDigits(resultDigits[0], resultDigits[1]);
+ Value resultElem1 = combineDigits(resultDigits[2], resultDigits[3]);
+ Value resultVec =
+ constructResultVector(rewriter, loc, newTy, {resultElem0, resultElem1});
+ rewriter.replaceOp(op, resultVec);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
// ConvertExtSI
//===----------------------------------------------------------------------===//
Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
Value extended = rewriter.createOrFold<arith::ExtSIOp>(
loc, newResultComponentTy, newOperand);
- Value operandZeroCst = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(newResultComponentTy));
+ Value operandZeroCst =
+ createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
Value signBit = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
Value signValue =
Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
Value extended = rewriter.createOrFold<arith::ExtUIOp>(
loc, newResultComponentTy, newOperand);
- Value zeroCst = rewriter.create<arith::ConstantOp>(
- op->getLoc(), rewriter.getZeroAttr(newTy));
+ Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
rewriter.replaceOp(op, newRes);
return success();
using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
void runOnOperation() override {
- if (!llvm::isPowerOf2_32(widestIntSupported)) {
+ if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
signalPassFailure();
return;
}
unsigned widestIntSupportedByTarget)
: maxIntWidth(widestIntSupportedByTarget) {
assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
- "Only power-of-two integers are supported");
+ "Only power-of-two integers with are supported");
+ assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
// Scalar case.
addConversion([this](IntegerType ty) -> Optional<Type> {
// Misc ops.
ConvertConstant, ConvertVectorPrint,
// Binary ops.
- ConvertAddI,
+ ConvertAddI, ConvertMulI,
// Extension and truncation ops.
ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
patterns.getContext());
%b = arith.trunci %a : vector<3xi64> to vector<3xi16>
return %b : vector<3xi16>
}
+
+// CHECK-LABEL: func.func @muli_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : vector<2xi32>
+// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : vector<2xi32>
+// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : vector<2xi32>
+//
+// CHECK-DAG: [[MASK:%.+]] = arith.constant 65535 : i32
+// CHECK-DAG: [[C16:%.+]] = arith.constant 16 : i32
+//
+// CHECK: [[LOWLOW0:%.+]] = arith.andi [[LOW0]], [[MASK]] : i32
+// CHECK-NEXT: [[HIGHLOW0:%.+]] = arith.shrui [[LOW0]], [[C16]] : i32
+// CHECK-NEXT: [[LOWHIGH0:%.+]] = arith.andi [[HIGH0]], [[MASK]] : i32
+// CHECK-NEXT: [[HIGHHIGH0:%.+]] = arith.shrui [[HIGH0]], [[C16]] : i32
+// CHECK-NEXT: [[LOWLOW1:%.+]] = arith.andi [[LOW1]], [[MASK]] : i32
+// CHECK-NEXT: [[HIGHLOW1:%.+]] = arith.shrui [[LOW1]], [[C16]] : i32
+// CHECK-NEXT: [[LOWHIGH1:%.+]] = arith.andi [[HIGH1]], [[MASK]] : i32
+// CHECK-NEXT: [[HIGHHIGH1:%.+]] = arith.shrui [[HIGH1]], [[C16]] : i32
+//
+// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWLOW1]] : i32
+// CHECK-DAG {{%.+}} = arith.muli [[LOWLOW0]], [[HIGHLOW1]] : i32
+// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[LOWHIGH1]] : i32
+// CHECK-DAG: {{%.+}} = arith.muli [[LOWLOW0]], [[HIGHHIGH1]] : i32
+//
+// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWLOW1]] : i32
+// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[HIGHLOW1]] : i32
+// CHECK-DAG: {{%.+}} = arith.muli [[HIGHLOW0]], [[LOWHIGH1]] : i32
+//
+// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[LOWLOW1]] : i32
+// CHECK-DAG: {{%.+}} = arith.muli [[LOWHIGH0]], [[HIGHLOW1]] : i32
+//
+// CHECK-DAG: {{%.+}} = arith.muli [[HIGHHIGH0]], [[LOWLOW1]] : i32
+//
+// CHECK: [[RESHIGH0:%.+]] = arith.shli {{%.+}}, [[C16]] : i32
+// CHECK-NEXT: [[RES0:%.+]] = arith.ori {{%.+}}, [[RESHIGH0]] : i32
+// CHECK-NEXT: [[RESHIGH1:%.+]] = arith.shli {{%.+}}, [[C16]] : i32
+// CHECK-NEXT: [[RES1:%.+]] = arith.ori {{%.+}}, [[RESHIGH1]] : i32
+// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[RES0]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[RES1]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT: return [[INS1]] : vector<2xi32>
+func.func @muli_scalar(%a : i64, %b : i64) -> i64 {
+ %m = arith.muli %a, %b : i64
+ return %m : i64
+}
+
+// CHECK-LABEL: func.func @muli_vector
+// CHECK-SAME: ({{%.+}}: vector<3x2xi32>, {{%.+}}: vector<3x2xi32>) -> vector<3x2xi32>
+// CHECK: return {{%.+}} : vector<3x2xi32>
+func.func @muli_vector(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi64> {
+ %m = arith.muli %a, %b : vector<3xi64>
+ return %m : vector<3xi64>
+}
--- /dev/null
+// Check that the wide integer multiplication emulation produces the same result as wide
+// multiplication. Emulate i16 ops with i8 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s --match-full-lines --check-prefix=WIDE
+
+// RUN: mlir-opt %s --arith-emulate-wide-int="widest-int-supported=8" \
+// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s --match-full-lines --check-prefix=EMULATED
+
+func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
+ %res = arith.muli %lhs, %rhs : i16
+ vector.print %res : i16
+ return
+}
+
+func.func @entry() {
+ %cst0 = arith.constant 0 : i16
+ %cst1 = arith.constant 1 : i16
+ %cst_1 = arith.constant -1 : i16
+ %cst_3 = arith.constant -3 : i16
+
+ %cst13 = arith.constant 13 : i16
+ %cst37 = arith.constant 37 : i16
+ %cst42 = arith.constant 42 : i16
+
+ %cst256 = arith.constant 256 : i16
+ %cst_i16_max = arith.constant 32767 : i16
+ %cst_i16_min = arith.constant -32768 : i16
+
+ // WIDE: 0
+ // EMULATED: ( 0, 0 )
+ func.call @check_muli(%cst0, %cst0) : (i16, i16) -> ()
+ // WIDE-NEXT: 0
+ // EMULATED-NEXT: ( 0, 0 )
+ func.call @check_muli(%cst0, %cst1) : (i16, i16) -> ()
+ // WIDE-NEXT: 1
+ // EMULATED-NEXT: ( 1, 0 )
+ func.call @check_muli(%cst1, %cst1) : (i16, i16) -> ()
+ // WIDE-NEXT: -1
+ // EMULATED-NEXT: ( -1, -1 )
+ func.call @check_muli(%cst1, %cst_1) : (i16, i16) -> ()
+ // WIDE-NEXT: 1
+ // EMULATED-NEXT: ( 1, 0 )
+ func.call @check_muli(%cst_1, %cst_1) : (i16, i16) -> ()
+ // WIDE-NEXT: -3
+ // EMULATED-NEXT: ( -3, -1 )
+ func.call @check_muli(%cst1, %cst_3) : (i16, i16) -> ()
+
+ // WIDE-NEXT: 169
+ // EMULATED-NEXT: ( -87, 0 )
+ func.call @check_muli(%cst13, %cst13) : (i16, i16) -> ()
+ // WIDE-NEXT: 481
+ // EMULATED-NEXT: ( -31, 1 )
+ func.call @check_muli(%cst13, %cst37) : (i16, i16) -> ()
+ // WIDE-NEXT: 1554
+ // EMULATED-NEXT: ( 18, 6 )
+ func.call @check_muli(%cst37, %cst42) : (i16, i16) -> ()
+
+ // WIDE-NEXT: -256
+ // EMULATED-NEXT: ( 0, -1 )
+ func.call @check_muli(%cst_1, %cst256) : (i16, i16) -> ()
+ // WIDE-NEXT: 3328
+ // EMULATED-NEXT: ( 0, 13 )
+ func.call @check_muli(%cst256, %cst13) : (i16, i16) -> ()
+ // WIDE-NEXT: 9472
+ // EMULATED-NEXT: ( 0, 37 )
+ func.call @check_muli(%cst256, %cst37) : (i16, i16) -> ()
+ // WIDE-NEXT: -768
+ // EMULATED-NEXT: ( 0, -3 )
+ func.call @check_muli(%cst256, %cst_3) : (i16, i16) -> ()
+
+ // WIDE-NEXT: 32755
+ // EMULATED-NEXT: ( -13, 127 )
+ func.call @check_muli(%cst13, %cst_i16_max) : (i16, i16) -> ()
+ // WIDE-NEXT: -32768
+ // EMULATED-NEXT: ( 0, -128 )
+ func.call @check_muli(%cst_i16_min, %cst37) : (i16, i16) -> ()
+
+ // WIDE-NEXT: 1
+ // EMULATED-NEXT: ( 1, 0 )
+ func.call @check_muli(%cst_i16_max, %cst_i16_max) : (i16, i16) -> ()
+ // WIDE-NEXT: -32768
+ // EMULATED-NEXT: ( 0, -128 )
+ func.call @check_muli(%cst_i16_min, %cst13) : (i16, i16) -> ()
+ // WIDE-NEXT: 0
+ // EMULATED-NEXT: ( 0, 0 )
+ func.call @check_muli(%cst_i16_min, %cst_i16_min) : (i16, i16) -> ()
+
+ return
+}