From: liqinweng Date: Fri, 30 Dec 2022 03:54:06 +0000 (+0800) Subject: [MLIR][Arith][NFC] Use the interface of 'getElementTypeOrSelf' to get the resType X-Git-Tag: upstream/17.0.6~22458 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f8f4fc11d1b9256685a44725b55e431c57aaade3;p=platform%2Fupstream%2Fllvm.git [MLIR][Arith][NFC] Use the interface of 'getElementTypeOrSelf' to get the resType Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D140608 --- diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index f6446ea..e61169c 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1176,12 +1176,9 @@ OpFoldResult arith::ExtUIOp::fold(ArrayRef operands) { getInMutable().assign(lhs.getIn()); return getResult(); } - Type resType = getType(); - unsigned bitWidth; - if (auto shapedType = resType.dyn_cast()) - bitWidth = shapedType.getElementTypeBitWidth(); - else - bitWidth = resType.getIntOrFloatBitWidth(); + + Type resType = getElementTypeOrSelf(getType()); + unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { return a.zext(bitWidth); @@ -1205,12 +1202,9 @@ OpFoldResult arith::ExtSIOp::fold(ArrayRef operands) { getInMutable().assign(lhs.getIn()); return getResult(); } - Type resType = getType(); - unsigned bitWidth; - if (auto shapedType = resType.dyn_cast()) - bitWidth = shapedType.getElementTypeBitWidth(); - else - bitWidth = resType.getIntOrFloatBitWidth(); + + Type resType = getElementTypeOrSelf(getType()); + unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { return a.sext(bitWidth); @@ -1259,13 +1253,8 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef operands) { return getResult(); } - Type resType = getType(); - unsigned bitWidth; - if (auto shapedType = resType.dyn_cast()) - bitWidth = shapedType.getElementTypeBitWidth(); - else - bitWidth = resType.getIntOrFloatBitWidth(); - + Type resType = getElementTypeOrSelf(getType()); + unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { return a.trunc(bitWidth); @@ -1361,12 +1350,7 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { } OpFoldResult arith::UIToFPOp::fold(ArrayRef operands) { - Type resType = getType(); - Type resEleType; - if (auto shapedType = resType.dyn_cast()) - resEleType = shapedType.getElementType(); - else - resEleType = resType; + Type resEleType = getElementTypeOrSelf(getType()); return constFoldCastOp( operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { FloatType floatTy = resEleType.cast(); @@ -1387,12 +1371,7 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { } OpFoldResult arith::SIToFPOp::fold(ArrayRef operands) { - Type resType = getType(); - Type resEleType; - if (auto shapedType = resType.dyn_cast()) - resEleType = shapedType.getElementType(); - else - resEleType = resType; + Type resEleType = getElementTypeOrSelf(getType()); return constFoldCastOp( operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { FloatType floatTy = resEleType.cast(); @@ -1412,17 +1391,12 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { } OpFoldResult arith::FPToUIOp::fold(ArrayRef operands) { - Type resType = getType(); - Type resEleType; - if (auto shapedType = resType.dyn_cast()) - resEleType = shapedType.getElementType(); - else - resEleType = resType; + Type resType = getElementTypeOrSelf(getType()); + unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( - operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { - IntegerType intTy = resEleType.cast(); + operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) { bool ignored; - APSInt api(intTy.getWidth(), /*isUnsigned=*/true); + APSInt api(bitWidth, /*isUnsigned=*/true); castStatus = APFloat::opInvalidOp != a.convertToInteger(api, APFloat::rmTowardZero, &ignored); return api; @@ -1438,17 +1412,12 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { } OpFoldResult arith::FPToSIOp::fold(ArrayRef operands) { - Type resType = getType(); - Type resEleType; - if (auto shapedType = resType.dyn_cast()) - resEleType = shapedType.getElementType(); - else - resEleType = resType; + Type resType = getElementTypeOrSelf(getType()); + unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( - operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { - IntegerType intTy = resEleType.cast(); + operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) { bool ignored; - APSInt api(intTy.getWidth(), /*isUnsigned=*/false); + APSInt api(bitWidth, /*isUnsigned=*/false); castStatus = APFloat::opInvalidOp != a.convertToInteger(api, APFloat::rmTowardZero, &ignored); return api;