[MLIR][Arith] Fold trunci with ext if the bit width of the input type of ext is great...
authorliqinweng <Liqin.Weng@streamcomputing.com>
Thu, 6 Apr 2023 13:08:34 +0000 (21:08 +0800)
committerliqinweng <Liqin.Weng@streamcomputing.com>
Thu, 6 Apr 2023 13:08:39 +0000 (21:08 +0800)
This patch is mainly to deal with folding trunci with ext,as flows:
    trunci(zexti(a)) -> trunci(a)
    trunci(zexti(a)) -> trunci(a)

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D140604

mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir

index d7ce71a..e203dbc 100644 (file)
@@ -1245,11 +1245,22 @@ LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
-  // trunci(zexti(a)) -> a
-  // trunci(sexti(a)) -> a
   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
-      matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
-    return getOperand().getDefiningOp()->getOperand(0);
+      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
+    Value src = getOperand().getDefiningOp()->getOperand(0);
+    Type srcType = getElementTypeOrSelf(src.getType());
+    Type dstType = getElementTypeOrSelf(getType());
+    // trunci(zexti(a)) -> trunci(a)
+    // trunci(sexti(a)) -> trunci(a)
+    if (srcType.cast<IntegerType>().getWidth() >
+        dstType.cast<IntegerType>().getWidth()) {
+      setOperand(src);
+      return getResult();
+    }
+    // trunci(zexti(a)) -> a
+    // trunci(sexti(a)) -> a
+    return src;
+  }
 
   // trunci(trunci(a)) -> trunci(a))
   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
index 0170620..1f96876 100644 (file)
@@ -619,6 +619,26 @@ func.func @truncExtui(%arg0: i32) -> i32 {
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtui2
+//       CHECK:  %[[ARG0:.+]]: i32
+//       CHECK:  %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : i32 to i16
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncExtui2(%arg0: i32) -> i16 {
+  %extui = arith.extui %arg0 : i32 to i64
+  %trunci = arith.trunci %extui : i64 to i16
+  return %trunci : i16
+}
+
+// CHECK-LABEL: @truncExtuiVector
+//       CHECK:  %[[ARG0:.+]]: vector<2xi32>
+//       CHECK:  %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : vector<2xi32> to vector<2xi16>
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncExtuiVector(%arg0: vector<2xi32>) -> vector<2xi16> {
+  %extsi = arith.extui %arg0 : vector<2xi32> to vector<2xi64>
+  %trunci = arith.trunci %extsi : vector<2xi64> to vector<2xi16>
+  return %trunci : vector<2xi16>
+}
+
 // CHECK-LABEL: @truncExtsi
 //       CHECK-NOT:  trunci
 //       CHECK:   return  %arg0
@@ -628,6 +648,26 @@ func.func @truncExtsi(%arg0: i32) -> i32 {
   return %trunci : i32
 }
 
+// CHECK-LABEL: @truncExtsi2
+//       CHECK:  %[[ARG0:.+]]: i32
+//       CHECK:  %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : i32 to i16
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncExtsi2(%arg0: i32) -> i16 {
+  %extsi = arith.extsi %arg0 : i32 to i64
+  %trunci = arith.trunci %extsi : i64 to i16
+  return %trunci : i16
+}
+
+// CHECK-LABEL: @truncExtsiVector
+//       CHECK:  %[[ARG0:.+]]: vector<2xi32>
+//       CHECK:  %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : vector<2xi32> to vector<2xi16>
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncExtsiVector(%arg0: vector<2xi32>) -> vector<2xi16> {
+  %extsi = arith.extsi %arg0 : vector<2xi32> to vector<2xi64>
+  %trunci = arith.trunci %extsi : vector<2xi64> to vector<2xi16>
+  return %trunci : vector<2xi16>
+}
+
 // CHECK-LABEL: @truncConstantSplat
 //       CHECK:   %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8>
 //       CHECK:   return %[[cres]]