From: Jakub Kuderski Date: Thu, 27 Apr 2023 15:13:46 +0000 (-0400) Subject: [mlir][arith] Add missing canon pattern `trunci(ext*i(x)) -> ext*i(x)` X-Git-Tag: upstream/17.0.6~10178 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ab85aec1affc92647c195f736d1bac69976baeb8;p=platform%2Fupstream%2Fllvm.git [mlir][arith] Add missing canon pattern `trunci(ext*i(x)) -> ext*i(x)` This pattern triggers when only the extension bits are truncated. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D149286 --- diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index d4c6b81..ba1f3f8 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -319,6 +319,20 @@ def TruncationMatchesShiftAmount : CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == " "*getIntOrSplatIntValue($2)">]>>; +// trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated +def TruncIExtSIToExtSI : + Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)), + (Arith_ExtSIOp $x), + [(ValueWiderThan $ext, $tr), + (ValueWiderThan $tr, $x)]>; + +// trunci(extui(x)) -> extui(x), when only the zero-extension bits are truncated +def TruncIExtUIToExtUI : + Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x)), + (Arith_ExtUIOp $x), + [(ValueWiderThan $ext, $tr), + (ValueWiderThan $tr, $x)]>; + // trunci(shrsi(x, c)) -> trunci(shrui(x, c)) def TruncIShrSIToTrunciShrUI : Pat<(Arith_TruncIOp:$tr diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 446bb64..b4b0572 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1290,8 +1290,9 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(context); + patterns.add( + context); } LogicalResult arith::TruncIOp::verify() { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 1f96876..14589b2 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -629,6 +629,16 @@ func.func @truncExtui2(%arg0: i32) -> i16 { return %trunci : i16 } +// CHECK-LABEL: @truncExtui3 +// CHECK: %[[ARG0:.+]]: i8 +// CHECK: %[[CST:.*]] = arith.extui %[[ARG0:.+]] : i8 to i16 +// CHECK: return %[[CST:.*]] : i16 +func.func @truncExtui3(%arg0: i8) -> i16 { + %extui = arith.extui %arg0 : i8 to i32 + %trunci = arith.trunci %extui : i32 to i16 + return %trunci : i16 +} + // CHECK-LABEL: @truncExtuiVector // CHECK: %[[ARG0:.+]]: vector<2xi32> // CHECK: %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : vector<2xi32> to vector<2xi16> @@ -658,6 +668,16 @@ func.func @truncExtsi2(%arg0: i32) -> i16 { return %trunci : i16 } +// CHECK-LABEL: @truncExtsi3 +// CHECK: %[[ARG0:.+]]: i8 +// CHECK: %[[CST:.*]] = arith.extsi %[[ARG0:.+]] : i8 to i16 +// CHECK: return %[[CST:.*]] : i16 +func.func @truncExtsi3(%arg0: i8) -> i16 { + %extsi = arith.extsi %arg0 : i8 to i32 + %trunci = arith.trunci %extsi : i32 to i16 + return %trunci : i16 +} + // CHECK-LABEL: @truncExtsiVector // CHECK: %[[ARG0:.+]]: vector<2xi32> // CHECK: %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : vector<2xi32> to vector<2xi16> diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 5cc0eb5..47a19bb 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1107,14 +1107,10 @@ func.func @fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi1> attributes {} // ----- -// TODO Canonicalize this into: -// arith.extui %arg0 : i1 to i2 - -// CHECK-LABEL: func @do_not_fold_trunci +// CHECK-LABEL: func @fold_trunci // CHECK-SAME: (%[[ARG0:[0-9a-z]*]]: i1) -func.func @do_not_fold_trunci(%arg0: i1) -> i2 attributes {} { - // CHECK-NEXT: arith.extui %[[ARG0]] : i1 to i8 - // CHECK-NEXT: %[[RES:[0-9a-z]*]] = arith.trunci %{{.*}} : i8 to i2 +func.func @fold_trunci(%arg0: i1) -> i2 attributes {} { + // CHECK-NEXT: %[[RES:[0-9a-z]*]] = arith.extui %[[ARG0]] : i1 to i2 // CHECK-NEXT: return %[[RES]] : i2 %0 = arith.extui %arg0 : i1 to i8 %1 = arith.trunci %0 : i8 to i2 @@ -1123,11 +1119,10 @@ func.func @do_not_fold_trunci(%arg0: i1) -> i2 attributes {} { // ----- -// CHECK-LABEL: func @do_not_fold_trunci_vector +// CHECK-LABEL: func @fold_trunci_vector // CHECK-SAME: (%[[ARG0:[0-9a-z]*]]: vector<4xi1>) -func.func @do_not_fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi2> attributes {} { - // CHECK-NEXT: arith.extui %[[ARG0]] : vector<4xi1> to vector<4xi8> - // CHECK-NEXT: %[[RES:[0-9a-z]*]] = arith.trunci %{{.*}} : vector<4xi8> to vector<4xi2> +func.func @fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi2> attributes {} { + // CHECK-NEXT: %[[RES:[0-9a-z]*]] = arith.extui %[[ARG0]] : vector<4xi1> to vector<4xi2> // CHECK-NEXT: return %[[RES]] : vector<4xi2> %0 = arith.extui %arg0 : vector<4xi1> to vector<4xi8> %1 = arith.trunci %0 : vector<4xi8> to vector<4xi2>