From: William S. Moses Date: Sun, 2 May 2021 04:39:45 +0000 (-0400) Subject: [MLIR] Canonicalization of Integer Cast Operations X-Git-Tag: llvmorg-14-init~7848 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=78720296f3912b4095eafa5c1277646fd67bae99;p=platform%2Fupstream%2Fllvm.git [MLIR] Canonicalization of Integer Cast Operations 1) Canonicalize IndexCast(SExt(x)) => IndexCast(x) 2) Provide constant folds of sign_extend and truncate Differential Revision: https://reviews.llvm.org/D101714 --- diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index aaa5a8a..99b8889 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1249,6 +1249,7 @@ def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> { }]; let hasFolder = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -1711,6 +1712,7 @@ def SignExtendIOp : Std_Op<"sexti", [NoSideEffect, let printer = [{ return printStandardCastOp(this->getOperation(), p); }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 9260e27..fc73947 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1241,6 +1241,29 @@ OpFoldResult IndexCastOp::fold(ArrayRef cstOperands) { return {}; } +namespace { +/// index_cast(sign_extend x) => index_cast(x) +struct IndexCastOfSExt : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IndexCastOp op, + PatternRewriter &rewriter) const override { + + if (auto extop = op.getOperand().getDefiningOp()) { + op.setOperand(extop.getOperand()); + return success(); + } + return failure(); + } +}; + +} // namespace + +void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// @@ -1439,6 +1462,20 @@ static LogicalResult verify(SignExtendIOp op) { return success(); } +OpFoldResult SignExtendIOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "unary operation takes one operand"); + + if (!operands[0]) + return {}; + + if (auto lhs = operands[0].dyn_cast()) { + return IntegerAttr::get( + getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); + } + + return {}; +} + //===----------------------------------------------------------------------===// // SignedDivIOp //===----------------------------------------------------------------------===// @@ -2686,7 +2723,18 @@ OpFoldResult TruncateIOp::fold(ArrayRef operands) { matchPattern(getOperand(), m_Op())) return getOperand().getDefiningOp()->getOperand(0); - return nullptr; + assert(operands.size() == 1 && "unary operation takes one operand"); + + if (!operands[0]) + return {}; + + if (auto lhs = operands[0].dyn_cast()) { + + return IntegerAttr::get( + getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); + } + + return {}; } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index 2e8a599..b4942de 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -698,15 +698,13 @@ func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) { // Check sign and zero extension and truncation of integers. // CHECK-LABEL: @integer_extension_and_truncation -func @integer_extension_and_truncation() { -// CHECK-NEXT: %0 = llvm.mlir.constant(-3 : i3) : i3 - %0 = constant 5 : i3 -// CHECK-NEXT: = llvm.sext %0 : i3 to i6 - %1 = sexti %0 : i3 to i6 -// CHECK-NEXT: = llvm.zext %0 : i3 to i6 - %2 = zexti %0 : i3 to i6 -// CHECK-NEXT: = llvm.trunc %0 : i3 to i2 - %3 = trunci %0 : i3 to i2 +func @integer_extension_and_truncation(%arg0 : i3) { +// CHECK-NEXT: = llvm.sext %arg0 : i3 to i6 + %0 = sexti %arg0 : i3 to i6 +// CHECK-NEXT: = llvm.zext %arg0 : i3 to i6 + %1 = zexti %arg0 : i3 to i6 +// CHECK-NEXT: = llvm.trunc %arg0 : i3 to i2 + %2 = trunci %arg0 : i3 to i2 return } diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir index e2b5e7b..8db3065 100644 --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -399,3 +399,32 @@ func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 { %1 = select %0, %arg0, %arg1 : i64 return %1 : i64 } + +// ----- + +// CHECK-LABEL: @indexCastOfSignExtend +// CHECK: %[[res:.+]] = index_cast %arg0 : i8 to index +// CHECK: return %[[res]] +func @indexCastOfSignExtend(%arg0: i8) -> index { + %ext = sexti %arg0 : i8 to i16 + %idx = index_cast %ext : i16 to index + return %idx : index +} + +// CHECK-LABEL: @signExtendConstant +// CHECK: %[[cres:.+]] = constant -2 : i16 +// CHECK: return %[[cres]] +func @signExtendConstant() -> i16 { + %c-2 = constant -2 : i8 + %ext = sexti %c-2 : i8 to i16 + return %ext : i16 +} + +// CHECK-LABEL: @truncConstant +// CHECK: %[[cres:.+]] = constant -2 : i16 +// CHECK: return %[[cres]] +func @truncConstant(%arg0: i8) -> i16 { + %c-2 = constant -2 : i32 + %tr = trunci %c-2 : i32 to i16 + return %tr : i16 +}