From 834c17f618ce87b14446e42250d924b8d5f01abe Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 13 Dec 2022 10:49:14 -0500 Subject: [PATCH] [mlir][arith] Add canonicalization patterns for 'mul*i_extended' - Add a fold for `mulsi_extended(x, 1)` - Add folds to demote wide integer multiplication to `mul*i_extended` when the result is shifted and truncated: `trunci(shrui(mul(*ext(x), *ext(y)), c)) -> mul*i_extended(x, y)` Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D139778 --- mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 1 + mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | 55 +++++++ mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 33 +++- mlir/test/Dialect/Arith/canonicalize.mlir | 169 +++++++++++++++++++++ 4 files changed, 257 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 594ba46..ce61890 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1071,6 +1071,7 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> { }]; let hasFolder = 1; + let hasCanonicalizer = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index cf2a767..bd891b9 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -122,6 +122,16 @@ def MulSIExtendedToMulI : [(Arith_MulIOp $x, $y), (replaceWithValue $x)], [(Constraint> $res__1)]>; +// mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)] +def MulSIExtendedRHSOne : + Pattern<(Arith_MulSIExtendedOp $x, (Arith_ConstantOp $c1)), + [(replaceWithValue $x), + (Arith_ExtSIOp(Arith_CmpIOp + (NativeCodeCall<"arith::CmpIPredicate::slt">), + $x, + (Arith_ConstantOp (GetZeroAttr $x))))], + [(Constraint> $c1)]>; + //===----------------------------------------------------------------------===// // MulUIExtendedOp //===----------------------------------------------------------------------===// @@ -244,6 +254,51 @@ def OrOfExtSI : [(Constraint> $x, $y)]>; //===----------------------------------------------------------------------===// +// TruncIOp +//===----------------------------------------------------------------------===// + +def ValuesWithSameType : + Constraint< + CPred<"llvm::all_equal({$0.getType(), $1.getType(), $2.getType()})">>; + +def ValueWiderThan : + Constraint< + CPred<"getScalarOrElementWidth($0) > getScalarOrElementWidth($1)">>; + +def TruncationMatchesShiftAmount : + Constraint< + CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == " + "getIntOrSplatIntValue($2)">>; + +// trunci(shrsi(x, c)) -> trunci(shrui(x, c)) +def TruncIShrSIToTrunciShrUI : + Pat<(Arith_TruncIOp:$tr (Arith_ShRSIOp $x, (Arith_ConstantOp $c0))), + (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp $c0))), + [(TruncationMatchesShiftAmount $x, $tr, $c0)]>; + +// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y) +def TruncIShrUIMulIToMulSIExtended : + Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp + (Arith_MulIOp:$mul + (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), + (Arith_ConstantOp $c0))), + (Arith_MulSIExtendedOp:$res__1 $x, $y), + [(ValuesWithSameType $tr, $x, $y), + (ValueWiderThan $mul, $x), + (TruncationMatchesShiftAmount $mul, $x, $c0)]>; + +// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y) +def TruncIShrUIMulIToMulUIExtended : + Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp + (Arith_MulIOp:$mul + (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)), + (Arith_ConstantOp $c0))), + (Arith_MulUIExtendedOp:$res__1 $x, $y), + [(ValuesWithSameType $tr, $x, $y), + (ValueWiderThan $mul, $x), + (TruncationMatchesShiftAmount $mul, $x, $c0)]>; + +//===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 25a3dd4..c09926a 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -74,6 +74,31 @@ static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { invertPredicate(pred.getValue())); } +static int64_t getScalarOrElementWidth(Type type) { + if (type.isIntOrFloat()) + return type.getIntOrFloatBitWidth(); + + if (auto shapeTy = type.dyn_cast()) + return shapeTy.getElementTypeBitWidth(); + + return -1; +} + +static int64_t getScalarOrElementWidth(Value value) { + return getScalarOrElementWidth(value.getType()); +} + +static int64_t getIntOrSplatIntValue(Attribute attr) { + if (auto intAttr = attr.dyn_cast()) + return intAttr.getInt(); + + if (auto splatAttr = attr.dyn_cast()) + if (splatAttr.getElementType().isa()) + return splatAttr.getSplatValue().getLimitedValue(); + + return -1; +} + //===----------------------------------------------------------------------===// // TableGen'd canonicalization patterns //===----------------------------------------------------------------------===// @@ -393,7 +418,7 @@ arith::MulSIExtendedOp::fold(ArrayRef operands, void arith::MulSIExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(context); + patterns.add(context); } //===----------------------------------------------------------------------===// @@ -1249,6 +1274,12 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { return checkWidthChangeCast(inputs, outputs); } +void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + LogicalResult arith::TruncIOp::verify() { return verifyTruncateOp(*this); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 644af88e..a4c800c 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -761,6 +761,30 @@ func.func @mulsiExtendedZeroLhs(%arg0: i32) -> (i32, i32) { return %low, %high : i32, i32 } +// CHECK-LABEL: @mulsiExtendedOneRhs +// CHECK-SAME: (%[[ARG:.+]]: i32) -> (i32, i32) +// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : i32 +// CHECK-NEXT: %[[CMP:.+]] = arith.cmpi slt, %[[ARG]], %[[C0]] : i32 +// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[CMP]] : i1 to i32 +// CHECK-NEXT: return %[[ARG]], %[[EXT]] : i32, i32 +func.func @mulsiExtendedOneRhs(%arg0: i32) -> (i32, i32) { + %one = arith.constant 1 : i32 + %low, %high = arith.mulsi_extended %arg0, %one: i32 + return %low, %high : i32, i32 +} + +// CHECK-LABEL: @mulsiExtendedOneRhsSplat +// CHECK-SAME: (%[[ARG:.+]]: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) +// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0> : vector<3xi32> +// CHECK-NEXT: %[[CMP:.+]] = arith.cmpi slt, %[[ARG]], %[[C0]] : vector<3xi32> +// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[CMP]] : vector<3xi1> to vector<3xi32> +// CHECK-NEXT: return %[[ARG]], %[[EXT]] : vector<3xi32>, vector<3xi32> +func.func @mulsiExtendedOneRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) { + %one = arith.constant dense<1> : vector<3xi32> + %low, %high = arith.mulsi_extended %arg0, %one: vector<3xi32> + return %low, %high : vector<3xi32>, vector<3xi32> +} + // CHECK-LABEL: @mulsiExtendedUnusedHigh // CHECK-SAME: (%[[ARG:.+]]: i32) -> i32 // CHECK-NEXT: %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32 @@ -1916,3 +1940,148 @@ func.func @andand3(%a : i32, %b : i32) -> i32 { %res = arith.andi %c, %b : i32 return %res : i32 } + +// ----- + +// CHECK-LABEL: @truncIShrSIToTrunciShrUI +// CHECK-SAME: (%[[A:.+]]: i64) +// CHECK-NEXT: %[[C32:.+]] = arith.constant 32 : i64 +// CHECK-NEXT: %[[SHR:.+]] = arith.shrui %[[A]], %[[C32]] : i64 +// CHECK-NEXT: %[[TRU:.+]] = arith.trunci %[[SHR]] : i64 to i32 +// CHECK-NEXT: return %[[TRU]] : i32 +func.func @truncIShrSIToTrunciShrUI(%a: i64) -> i32 { + %c32 = arith.constant 32: i64 + %sh = arith.shrsi %a, %c32 : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + +// CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt1 +// CHECK: arith.shrsi +func.func @truncIShrSIToTrunciShrUIBadShiftAmt1(%a: i64) -> i32 { + %c33 = arith.constant 33: i64 + %sh = arith.shrsi %a, %c33 : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + +// CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt2 +// CHECK: arith.shrsi +func.func @truncIShrSIToTrunciShrUIBadShiftAmt2(%a: i64) -> i32 { + %c31 = arith.constant 31: i64 + %sh = arith.shrsi %a, %c31 : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + +// CHECK-LABEL: @wideMulToMulSIExtended +// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32) +// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : i32 +// CHECK-NEXT: return %[[HIGH]] : i32 +func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 { + %x = arith.extsi %a: i32 to i64 + %y = arith.extsi %b: i32 to i64 + %m = arith.muli %x, %y: i64 + %c32 = arith.constant 32: i64 + %sh = arith.shrui %m, %c32 : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + +// CHECK-LABEL: @wideMulToMulSIExtendedVector +// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>) +// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32> +// CHECK-NEXT: return %[[HIGH]] : vector<3xi32> +func.func @wideMulToMulSIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> { + %x = arith.extsi %a: vector<3xi32> to vector<3xi64> + %y = arith.extsi %b: vector<3xi32> to vector<3xi64> + %m = arith.muli %x, %y: vector<3xi64> + %c32 = arith.constant dense<32>: vector<3xi64> + %sh = arith.shrui %m, %c32 : vector<3xi64> + %hi = arith.trunci %sh: vector<3xi64> to vector<3xi32> + return %hi : vector<3xi32> +} + +// CHECK-LABEL: @wideMulToMulUIExtended +// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32) +// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : i32 +// CHECK-NEXT: return %[[HIGH]] : i32 +func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 { + %x = arith.extui %a: i32 to i64 + %y = arith.extui %b: i32 to i64 + %m = arith.muli %x, %y: i64 + %c32 = arith.constant 32: i64 + %sh = arith.shrui %m, %c32 : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + +// CHECK-LABEL: @wideMulToMulUIExtendedVector +// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>) +// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32> +// CHECK-NEXT: return %[[HIGH]] : vector<3xi32> +func.func @wideMulToMulUIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> { + %x = arith.extui %a: vector<3xi32> to vector<3xi64> + %y = arith.extui %b: vector<3xi32> to vector<3xi64> + %m = arith.muli %x, %y: vector<3xi64> + %c32 = arith.constant dense<32>: vector<3xi64> + %sh = arith.shrui %m, %c32 : vector<3xi64> + %hi = arith.trunci %sh: vector<3xi64> to vector<3xi32> + return %hi : vector<3xi32> +} + +// CHECK-LABEL: @wideMulToMulIExtendedMixedExt +// CHECK: arith.muli +// CHECK: arith.shrui +// CHECK: arith.trunci +func.func @wideMulToMulIExtendedMixedExt(%a: i32, %b: i32) -> i32 { + %x = arith.extsi %a: i32 to i64 + %y = arith.extui %b: i32 to i64 + %m = arith.muli %x, %y: i64 + %c32 = arith.constant 32: i64 + %sh = arith.shrui %m, %c32 : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + +// CHECK-LABEL: @wideMulToMulSIExtendedBadExt +// CHECK: arith.muli +// CHECK: arith.shrui +// CHECK: arith.trunci +func.func @wideMulToMulSIExtendedBadExt(%a: i16, %b: i16) -> i32 { + %x = arith.extsi %a: i16 to i64 + %y = arith.extsi %b: i16 to i64 + %m = arith.muli %x, %y: i64 + %c32 = arith.constant 32: i64 + %sh = arith.shrui %m, %c32 : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + +// CHECK-LABEL: @wideMulToMulSIExtendedBadShift1 +// CHECK: arith.muli +// CHECK: arith.shrui +// CHECK: arith.trunci +func.func @wideMulToMulSIExtendedBadShift1(%a: i32, %b: i32) -> i32 { + %x = arith.extsi %a: i32 to i64 + %y = arith.extsi %b: i32 to i64 + %m = arith.muli %x, %y: i64 + %c33 = arith.constant 33: i64 + %sh = arith.shrui %m, %c33 : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} + +// CHECK-LABEL: @wideMulToMulSIExtendedBadShift2 +// CHECK: arith.muli +// CHECK: arith.shrui +// CHECK: arith.trunci +func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 { + %x = arith.extsi %a: i32 to i64 + %y = arith.extsi %b: i32 to i64 + %m = arith.muli %x, %y: i64 + %c31 = arith.constant 31: i64 + %sh = arith.shrui %m, %c31 : i64 + %hi = arith.trunci %sh: i64 to i32 + return %hi : i32 +} -- 2.7.4