From dd39f9b418379264ceb6a232dc0b2a5fb18a4203 Mon Sep 17 00:00:00 2001 From: liqinweng Date: Thu, 6 Apr 2023 21:08:34 +0800 Subject: [PATCH] [MLIR][Arith] Fold trunci with ext if the bit width of the input type of ext is greater than the MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit This patch is mainly to deal with folding trunci with ext,as flows: trunci(zexti(a)) -> trunci(a) trunci(zexti(a)) -> trunci(a) Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D140604 --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 19 +++++++++++---- mlir/test/Dialect/Arith/canonicalize.mlir | 40 +++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index d7ce71a..e203dbc 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1245,11 +1245,22 @@ LogicalResult arith::ExtFOp::verify() { return verifyExtOp(*this); } //===----------------------------------------------------------------------===// OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) { - // trunci(zexti(a)) -> a - // trunci(sexti(a)) -> a if (matchPattern(getOperand(), m_Op()) || - matchPattern(getOperand(), m_Op())) - return getOperand().getDefiningOp()->getOperand(0); + matchPattern(getOperand(), m_Op())) { + Value src = getOperand().getDefiningOp()->getOperand(0); + Type srcType = getElementTypeOrSelf(src.getType()); + Type dstType = getElementTypeOrSelf(getType()); + // trunci(zexti(a)) -> trunci(a) + // trunci(sexti(a)) -> trunci(a) + if (srcType.cast().getWidth() > + dstType.cast().getWidth()) { + setOperand(src); + return getResult(); + } + // trunci(zexti(a)) -> a + // trunci(sexti(a)) -> a + return src; + } // trunci(trunci(a)) -> trunci(a)) if (matchPattern(getOperand(), m_Op())) { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 0170620..1f96876 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -619,6 +619,26 @@ func.func @truncExtui(%arg0: i32) -> i32 { return %trunci : i32 } +// CHECK-LABEL: @truncExtui2 +// CHECK: %[[ARG0:.+]]: i32 +// CHECK: %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : i32 to i16 +// CHECK: return %[[CST:.*]] +func.func @truncExtui2(%arg0: i32) -> i16 { + %extui = arith.extui %arg0 : i32 to i64 + %trunci = arith.trunci %extui : i64 to i16 + return %trunci : i16 +} + +// CHECK-LABEL: @truncExtuiVector +// CHECK: %[[ARG0:.+]]: vector<2xi32> +// CHECK: %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : vector<2xi32> to vector<2xi16> +// CHECK: return %[[CST:.*]] +func.func @truncExtuiVector(%arg0: vector<2xi32>) -> vector<2xi16> { + %extsi = arith.extui %arg0 : vector<2xi32> to vector<2xi64> + %trunci = arith.trunci %extsi : vector<2xi64> to vector<2xi16> + return %trunci : vector<2xi16> +} + // CHECK-LABEL: @truncExtsi // CHECK-NOT: trunci // CHECK: return %arg0 @@ -628,6 +648,26 @@ func.func @truncExtsi(%arg0: i32) -> i32 { return %trunci : i32 } +// CHECK-LABEL: @truncExtsi2 +// CHECK: %[[ARG0:.+]]: i32 +// CHECK: %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : i32 to i16 +// CHECK: return %[[CST:.*]] +func.func @truncExtsi2(%arg0: i32) -> i16 { + %extsi = arith.extsi %arg0 : i32 to i64 + %trunci = arith.trunci %extsi : i64 to i16 + return %trunci : i16 +} + +// CHECK-LABEL: @truncExtsiVector +// CHECK: %[[ARG0:.+]]: vector<2xi32> +// CHECK: %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : vector<2xi32> to vector<2xi16> +// CHECK: return %[[CST:.*]] +func.func @truncExtsiVector(%arg0: vector<2xi32>) -> vector<2xi16> { + %extsi = arith.extsi %arg0 : vector<2xi32> to vector<2xi64> + %trunci = arith.trunci %extsi : vector<2xi64> to vector<2xi16> + return %trunci : vector<2xi16> +} + // CHECK-LABEL: @truncConstantSplat // CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8> // CHECK: return %[[cres]] -- 2.7.4